diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs index 628d955907..66bfbc6076 100644 --- a/src/Microsoft.ML.Core/Data/ColumnType.cs +++ b/src/Microsoft.ML.Core/Data/ColumnType.cs @@ -748,8 +748,7 @@ public override bool Equals(ColumnType other) if (other == this) return true; - var tmp = other as KeyType; - if (tmp == null) + if (!(other is KeyType tmp)) return false; if (RawKind != tmp.RawKind) return false; diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index d3dc3d7de8..74da229bc4 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -332,11 +332,9 @@ internal static void GetSlotNames(RoleMappedSchema schema, RoleMappedSchema.Colu Contracts.CheckValueOrNull(schema); Contracts.CheckParam(vectorSize >= 0, nameof(vectorSize)); - IReadOnlyList list; - if ((list = schema?.GetColumns(role)) == null || list.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize)) - { + IReadOnlyList list = schema?.GetColumns(role); + if (list?.Count != 1 || !schema.Schema[list[0].Index].HasSlotNames(vectorSize)) VBufferUtils.Resize(ref slotNames, vectorSize, 0); - } else schema.Schema[list[0].Index].Metadata.GetValue(Kinds.SlotNames, ref slotNames); } diff --git a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs index 9a3f02cc24..643136b6e0 100644 --- a/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs +++ b/src/Microsoft.ML.Core/Data/RoleMappedSchema.cs @@ -8,69 +8,6 @@ namespace Microsoft.ML.Runtime.Data { - /// - /// This contains information about a column in an . It is essentially a convenience cache - /// containing the name, column index, and column type for the column. 
The intended usage is that users of - /// will have a convenient method of getting the index and type without having to separately query it through the , - /// since practically the first thing a consumer of a will want to do once they get a mappping is get - /// the type and index of the corresponding column. - /// - public sealed class ColumnInfo - { - public readonly string Name; - public readonly int Index; - public readonly ColumnType Type; - - private ColumnInfo(string name, int index, ColumnType type) - { - Name = name; - Index = index; - Type = type; - } - - /// - /// Create a ColumnInfo for the column with the given name in the given schema. Throws if the name - /// doesn't map to a column. - /// - public static ColumnInfo CreateFromName(Schema schema, string name, string descName) - { - if (!TryCreateFromName(schema, name, out var colInfo)) - throw Contracts.ExceptParam(nameof(name), $"{descName} column '{name}' not found"); - - return colInfo; - } - - /// - /// Tries to create a ColumnInfo for the column with the given name in the given schema. Returns - /// false if the name doesn't map to a column. - /// - public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo colInfo) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckNonEmpty(name, nameof(name)); - - colInfo = null; - if (!schema.TryGetColumnIndex(name, out int index)) - return false; - - colInfo = new ColumnInfo(name, index, schema[index].Type); - return true; - } - - /// - /// Creates a ColumnInfo for the column with the given column index. Note that the name - /// of the column might actually map to a different column, so this should be used with care - /// and rarely. 
- /// - public static ColumnInfo CreateFromIndex(Schema schema, int index) - { - Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(0 <= index && index < schema.Count, nameof(index)); - - return new ColumnInfo(schema[index].Name, index, schema[index].Type); - } - } - /// /// Encapsulates an plus column role mapping information. The purpose of role mappings is to /// provide information on what the intended usage is for. That is: while a given data view may have a column named @@ -192,32 +129,32 @@ public static KeyValuePair CreatePair(ColumnRole role, strin /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Feature { get; } + public Schema.Column? Feature { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Label { get; } + public Schema.Column? Label { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Group { get; } + public Schema.Column? Group { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Weight { get; } + public Schema.Column? Weight { get; } /// /// The column, when there is exactly one (null otherwise). /// - public ColumnInfo Name { get; } + public Schema.Column? Name { get; } // Maps from role to the associated column infos. 
- private readonly Dictionary> _map; + private readonly Dictionary> _map; - private RoleMappedSchema(Schema schema, Dictionary> map) + private RoleMappedSchema(Schema schema, Dictionary> map) { Contracts.AssertValue(schema); Contracts.AssertValue(map); @@ -256,42 +193,40 @@ private RoleMappedSchema(Schema schema, Dictionary> map) + private RoleMappedSchema(Schema schema, Dictionary> map) : this(schema, Copy(map)) { } - private static void Add(Dictionary> map, ColumnRole role, ColumnInfo info) + private static void Add(Dictionary> map, ColumnRole role, Schema.Column column) { Contracts.AssertValue(map); Contracts.AssertNonEmpty(role.Value); - Contracts.AssertValue(info); if (!map.TryGetValue(role.Value, out var list)) { - list = new List(); + list = new List(); map.Add(role.Value, list); } - list.Add(info); + list.Add(column); } - private static Dictionary> MapFromNames(Schema schema, IEnumerable> roles, bool opt = false) + private static Dictionary> MapFromNames(Schema schema, IEnumerable> roles, bool opt = false) { Contracts.AssertValue(schema); Contracts.AssertValue(roles); - var map = new Dictionary>(); + var map = new Dictionary>(); foreach (var kvp in roles) { Contracts.AssertNonEmpty(kvp.Key.Value); if (string.IsNullOrEmpty(kvp.Value)) continue; - ColumnInfo info; - if (!opt) - info = ColumnInfo.CreateFromName(schema, kvp.Value, kvp.Key.Value); - else if (!ColumnInfo.TryCreateFromName(schema, kvp.Value, out info)) - continue; - Add(map, kvp.Key.Value, info); + var info = schema.GetColumnOrNull(kvp.Value); + if (info.HasValue) + Add(map, kvp.Key.Value, info.Value); + else if (!opt) + throw Contracts.ExceptParam(nameof(schema), $"{kvp.Key.Value} column '{kvp.Value}' not found"); } return map; } @@ -318,18 +253,18 @@ public bool HasMultiple(ColumnRole role) /// If there are columns of the given role, this returns the infos as a readonly list. Otherwise, /// it returns null. 
/// - public IReadOnlyList GetColumns(ColumnRole role) + public IReadOnlyList GetColumns(ColumnRole role) => _map.TryGetValue(role.Value, out var list) ? list : null; /// /// An enumerable over all role-column associations within this object. /// - public IEnumerable> GetColumnRoles() + public IEnumerable> GetColumnRoles() { foreach (var roleAndList in _map) { foreach (var info in roleAndList.Value) - yield return new KeyValuePair(roleAndList.Key, info); + yield return new KeyValuePair(roleAndList.Key, info); } } @@ -359,13 +294,13 @@ public IEnumerable> GetColumnRoleNames(ColumnRo } /// - /// Returns the corresponding to if there is + /// Returns the corresponding to if there is /// exactly one such mapping, and otherwise throws an exception. /// /// The role to look up - /// The info corresponding to that role, assuming there was only one column + /// The column corresponding to that role, assuming there was only one column /// mapped to that - public ColumnInfo GetUniqueColumn(ColumnRole role) + public Schema.Column GetUniqueColumn(ColumnRole role) { var infos = GetColumns(role); if (Utils.Size(infos) != 1) @@ -373,9 +308,9 @@ public ColumnInfo GetUniqueColumn(ColumnRole role) return infos[0]; } - private static Dictionary> Copy(Dictionary> map) + private static Dictionary> Copy(Dictionary> map) { - var copy = new Dictionary>(map.Count); + var copy = new Dictionary>(map.Count); foreach (var kvp in map) { Contracts.Assert(Utils.Size(kvp.Value) > 0); diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 95be4d8f69..96cc366397 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -8,7 +8,6 @@ using System.Linq; using System.Reflection; using Microsoft.ML.Data; -using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; @@ -36,7 +35,7 @@ internal sealed class Transposer : 
ITransposeDataView, IDisposable public readonly int RowCount; // -1 for input columns that were not transposed, a non-negative index into _cols for those that were. private readonly int[] _inputToTransposed; - private readonly ColumnInfo[] _cols; + private readonly Schema.Column[] _cols; private readonly int[] _splitLim; private readonly SchemaImpl _tschema; private bool _disposed; @@ -104,13 +103,13 @@ private Transposer(IHost host, IDataView view, bool forceSave, int[] columns) columnSet = columnSet.Where(c => ttschema.GetSlotType(c) == null); } columns = columnSet.ToArray(); - _cols = new ColumnInfo[columns.Length]; + _cols = new Schema.Column[columns.Length]; var schema = _view.Schema; _nameToICol = new Dictionary(); _inputToTransposed = Utils.CreateArray(schema.Count, -1); for (int c = 0; c < columns.Length; ++c) { - _nameToICol[(_cols[c] = ColumnInfo.CreateFromIndex(schema, columns[c])).Name] = c; + _nameToICol[(_cols[c] = schema[columns[c]]).Name] = c; _inputToTransposed[columns[c]] = c; } @@ -305,7 +304,7 @@ public SchemaImpl(Transposer parent) _slotTypes = new VectorType[_parent._cols.Length]; for (int c = 0; c < _slotTypes.Length; ++c) { - ColumnInfo srcInfo = _parent._cols[c]; + var srcInfo = _parent._cols[c]; var ctype = srcInfo.Type.ItemType; var primitiveType = ctype as PrimitiveType; _ectx.Assert(primitiveType != null); diff --git a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs index 8364b17060..38867fe005 100644 --- a/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs +++ b/src/Microsoft.ML.Data/Depricated/Instances/HeaderSchema.cs @@ -185,13 +185,14 @@ public static FeatureNameCollection Create(RoleMappedSchema schema) { // REVIEW: This shim should be deleted as soon as is convenient. 
Contracts.CheckValue(schema, nameof(schema)); - Contracts.CheckParam(schema.Feature != null, nameof(schema), "Cannot create feature name collection if we have no features"); - Contracts.CheckParam(schema.Feature.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size"); + Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Cannot create feature name collection if we have no features"); + var featureCol = schema.Feature.Value; + Contracts.CheckParam(featureCol.Type.ValueCount > 0, nameof(schema), "Cannot create feature name collection if our features are not of known size"); VBuffer> slotNames = default; - int len = schema.Feature.Type.ValueCount; - if (schema.Schema[schema.Feature.Index].HasSlotNames(len)) - schema.Schema[schema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); + int len = featureCol.Type.ValueCount; + if (featureCol.HasSlotNames(len)) + featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); else slotNames = VBufferUtils.CreateEmpty>(len); var slotNameValues = slotNames.GetValues(); diff --git a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs index 987246eea5..49802efb39 100644 --- a/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs +++ b/src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs @@ -125,12 +125,11 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType lab labelType = null; if (trainRms.Label != null) { - labelType = trainRms.Label.Type; - if (labelType.IsKey && - trainRms.Schema[trainRms.Label.Index].HasKeyValues(labelType.KeyCount)) + labelType = trainRms.Label.Value.Type; + if (labelType is KeyType && trainRms.Label.Value.HasKeyValues(labelType.KeyCount)) { VBuffer> keyValues = default; - trainRms.Schema[trainRms.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues); + 
trainRms.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyValues); return keyValues.DenseValues().Select(v => v.ToString()).ToArray(); } } diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 905bdaaa13..fff79ebf75 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -97,15 +97,15 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) var t = score.Type; if (t != NumberType.Float) throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t).MarkSensitive(MessageSensitivity.Schema); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + Host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.Float && t.KeyCount != 2) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Name, t).MarkSensitive(MessageSensitivity.Schema); + throw Host.Except("Label column '{0}' has type '{1}' but must be R4 or a 2-value key", schema.Label.Value.Name, t).MarkSensitive(MessageSensitivity.Schema); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { - return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? -1 : schema.Name.Index, stratName); + return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, schema.Name == null ? 
-1 : schema.Name.Value.Index, stratName); } internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data) @@ -501,11 +501,11 @@ private void FinishOtherMetrics() internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(!_streaming && PassNum < 2 || PassNum < 1); - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); @@ -745,13 +745,13 @@ private protected override IDataView GetOverallResultsCore(IDataView overall) private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column"); // The anomaly detection evaluator outputs the label and the score. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.AnomalyDetection); - yield return scoreInfo.Name; + yield return scoreCol.Name; // No additional output columns. 
} diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 4468ba000d..7fed38a710 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -129,11 +129,11 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) var host = Host.SchemaSensitive(); var t = score.Type; if (t.IsVector || t.ItemType != NumberType.Float) - throw host.SchemaSensitive().Except("Score column '{0}' has type '{1}' but must be R4", score, t); - host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4", t.ToString()); + host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4 && t != NumberType.R8 && t != BoolType.Instance && t.KeyCount != 2) - throw host.SchemaSensitive().Except("Label column '{0}' has type '{1}' but must be R4, R8, BL or a 2-value key", schema.Label.Name, t); + throw host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4, R8, BL or a 2-value key", t.ToString()); } private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) @@ -142,14 +142,14 @@ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema sche var host = Host.SchemaSensitive(); if (prob != null) { - host.Check(prob.Count == 1, "Cannot have multiple probability columns"); + host.CheckParam(prob.Count == 1, nameof(schema), "Cannot have multiple probability columns"); var probType = prob[0].Type; if (probType != NumberType.Float) - throw host.SchemaSensitive().Except("Probability column '{0}' has type '{1}' but must be R4", prob[0].Name, probType); + throw host.ExceptSchemaMismatch(nameof(schema), "probability", prob[0].Name, "R4", probType.ToString()); } else if (!_useRaw) 
{ - throw host.Except( + throw host.ExceptParam(nameof(schema), "Cannot compute the predicted label from the probability column because it does not exist"); } } @@ -172,13 +172,13 @@ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) { // Get the label names if they exist, or use the default names. - ColumnType type; var labelNames = default(VBuffer>); - if (schema.Label.Type.IsKey && - (type = schema.Schema[schema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type) != null && - type.ItemType.IsKnownSizeVector && type.ItemType.IsText) + var labelCol = schema.Label.Value; + if (labelCol.Type is KeyType && + labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type is VectorType vecType && + vecType.Size > 0 && vecType.ItemType == TextType.Instance) { - schema.Schema[schema.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); + labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); } else labelNames = new VBuffer>(2, new[] { "positive".AsMemory(), "negative".AsMemory() }); @@ -193,11 +193,10 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Contracts.AssertValue(scoreInfo); var probInfos = schema.GetColumns(MetadataUtils.Const.ScoreValueKind.Probability); var probCol = Utils.Size(probInfos) > 0 ? 
probInfos[0].Name : null; - return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Name, _threshold, _useRaw); + return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Value.Name, _threshold, _useRaw); } public override IEnumerable GetOverallMetricColumns() @@ -611,12 +610,12 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, bool internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); Host.Assert(PassNum < 1); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); @@ -631,7 +630,7 @@ internal override void InitializeNextPass(Row row, RoleMappedSchema schema) Host.Assert((schema.Weight != null) == Weighted); if (Weighted) - _weightGetter = row.GetGetter(schema.Weight.Index); + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -1176,14 +1175,14 @@ public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args) { var cols = base.GetInputColumnRolesCore(schema); - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.BinaryClassification); // Get the optional probability column. 
- var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), - scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float); - if (probInfo != null) - cols = MetadataUtils.Prepend(cols, RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name)); + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); + if (probCol.HasValue) + cols = MetadataUtils.Prepend(cols, RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name)); return cols; } @@ -1482,19 +1481,19 @@ private void SavePrPlots(List prList) private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // The binary classifier evaluator outputs the label, score and probability columns. 
- yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.BinaryClassification); - yield return scoreInfo.Name; - var probInfo = EvaluateUtils.GetOptAuxScoreColumnInfo(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), - scoreInfo.Index, MetadataUtils.Const.ScoreValueKind.Probability, t => t == NumberType.Float); + yield return scoreCol.Name; + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); // Return the output columns. The LogLoss column is returned only if the probability column exists. - if (probInfo != null) + if (probCol.HasValue) { - yield return probInfo.Name; + yield return probCol.Value.Name; yield return BinaryPerInstanceEvaluator.LogLoss; } diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index fa60665447..79d8ed42b7 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -96,30 +96,29 @@ public ClusteringMetrics Evaluate(IDataView data, string score, string label = n private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { - ColumnType type; - if (schema.Label != null && (type = schema.Label.Type) != NumberType.Float && type.KeyCount == 0) + ColumnType type = schema.Label?.Type; + if (type != null && type != NumberType.Float && !(type is KeyType keyType && keyType.Count > 0)) { - throw Host.Except("Clustering evaluator: label column '{0}' type must be {1} or Key of known cardinality." 
+ - " Provide a correct label column, or none: it is optional.", - schema.Label.Name, NumberType.Float); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, + "R4 or key of known cardinality", type.ToString()); } var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); type = score.Type; if (!type.IsKnownSizeVector || type.ItemType != NumberType.Float) - throw Host.Except("Scores column '{0}' type must be a float vector of known size", score.Name); + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4 vector of known size", type.ToString()); } private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema) { if (_calculateDbi) { - Host.AssertValue(schema.Feature); - var t = schema.Feature.Type; + Host.Assert(schema.Feature.HasValue); + var t = schema.Feature.Value.Type; if (!t.IsKnownSizeVector || t.ItemType != NumberType.Float) { - throw Host.Except("Features column '{0}' type must be {1} vector of known-size", - schema.Feature.Name, NumberType.Float); + throw Host.ExceptSchemaMismatch(nameof(schema), "features", schema.Feature.Value.Name, + "R4 vector of known size", t.ToString()); } } } @@ -129,13 +128,13 @@ private protected override Func GetActiveColsCore(RoleMappedSchema sc var pred = base.GetActiveColsCore(schema); // We also need the features column for dbi calculation. 
Host.Assert(!_calculateDbi || schema.Feature != null); - return i => _calculateDbi && i == schema.Feature.Index || pred(i); + return i => _calculateDbi && i == schema.Feature.Value.Index || pred(i); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) { Host.AssertValue(schema); - Host.Assert(!_calculateDbi || (schema.Feature != null && schema.Feature.Type.IsKnownSizeVector)); + Host.Assert(!_calculateDbi || schema.Feature?.Type.IsKnownSizeVector == true); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize > 0); int numClusters = score.Type.VectorSize; @@ -316,7 +315,7 @@ public Double Dbi } } - public Counters(int numClusters, bool calculateDbi, ColumnInfo features) + public Counters(int numClusters, bool calculateDbi, Schema.Column? features) { _numClusters = numClusters; CalculateDbi = calculateDbi; @@ -326,10 +325,10 @@ public Counters(int numClusters, bool calculateDbi, ColumnInfo features) _confusionMatrix = new List(); if (CalculateDbi) { - Contracts.AssertValue(features); + Contracts.Assert(features.HasValue); _clusterCentroids = new VBuffer[_numClusters]; for (int i = 0; i < _numClusters; i++) - _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Type.VectorSize); + _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Value.Type.VectorSize); _distancesToCentroids = new Double[_numClusters]; } } @@ -396,7 +395,7 @@ public void UpdateSecondPass(in VBuffer features, int[] indices) private readonly bool _calculateDbi; - public Aggregator(IHostEnvironment env, ColumnInfo features, int scoreVectorSize, bool calculateDbi, bool weighted, string stratName) + internal Aggregator(IHostEnvironment env, Schema.Column? 
features, int scoreVectorSize, bool calculateDbi, bool weighted, string stratName) : base(env, stratName) { _calculateDbi = calculateDbi; @@ -407,10 +406,10 @@ public Aggregator(IHostEnvironment env, ColumnInfo features, int scoreVectorSize WeightedCounters = Weighted ? new Counters(scoreVectorSize, _calculateDbi, features) : null; if (_calculateDbi) { - Host.AssertValue(features); + Host.Assert(features.HasValue); _clusterCentroids = new VBuffer[scoreVectorSize]; for (int i = 0; i < scoreVectorSize; i++) - _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Type.VectorSize); + _clusterCentroids[i] = VBufferUtils.CreateEmpty(features.Value.Type.VectorSize); } } @@ -493,8 +492,8 @@ internal override void InitializeNextPass(Row row, RoleMappedSchema schema) if (_calculateDbi) { - Host.AssertValue(schema.Feature); - _featGetter = row.GetGetter>(schema.Feature.Index); + Host.Assert(schema.Feature.HasValue); + _featGetter = row.GetGetter>(schema.Feature.Value.Index); } var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize == _scoresArr.Length); @@ -502,12 +501,12 @@ internal override void InitializeNextPass(Row row, RoleMappedSchema schema) if (PassNum == 0) { - if (schema.Label != null) - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + if (schema.Label.HasValue) + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); else _labelGetter = (ref Single value) => value = Single.NaN; - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } else { @@ -821,8 +820,8 @@ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleM Host.CheckValue(schema, nameof(schema)); // Output the label column if it exists. 
- if (schema.Label != null) - yield return schema.Label.Name; + if (schema.Label.HasValue) + yield return schema.Label.Value.Name; // Return the output columns. yield return ClusteringPerInstanceEvaluator.ClusterId; diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs index 287d882645..5a897a8793 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorBase.cs @@ -49,8 +49,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) private protected void CheckColumnTypes(RoleMappedSchema schema) { // Check the weight column type. - if (schema.Weight != null) - EvaluateUtils.CheckWeightType(Host, schema.Weight.Type); + if (schema.Weight.HasValue) + EvaluateUtils.CheckWeightType(Host, schema.Weight.Value.Type); CheckScoreAndLabelTypes(schema); // Check the other column types. CheckCustomColumnTypesCore(schema); @@ -92,8 +92,8 @@ private Func GetActiveCols(RoleMappedSchema schema) private protected virtual Func GetActiveColsCore(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - var label = schema.Label == null ? -1 : schema.Label.Index; - var weight = schema.Weight == null ? -1 : schema.Weight.Index; + int label = schema.Label?.Index ?? -1; + int weight = schema.Weight?.Index ?? -1; return i => i == score.Index || i == label || i == weight; } diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index a7da5b2efb..095ad1bb5c 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -101,11 +101,11 @@ private static bool CheckScoreColumnKind(Schema schema, int col) } /// - /// Find the score column to use. If name is specified, that is used. Otherwise, this searches for the - /// most recent score set of the given kind. 
If there is no such score set and defName is specifed it - /// uses defName. Otherwise, it throws. + /// Find the score column to use. If <paramref name="name"/> is specified, that is used. Otherwise, this searches + /// for the most recent score set of the given <paramref name="kind"/>. If there is no such score set and + /// <paramref name="defName"/> is specified it uses <paramref name="defName"/>. Otherwise, it throws. /// - public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, string kind, + public static Schema.Column GetScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName, string kind, string valueKind = MetadataUtils.Const.ScoreValueKind.Score, string defName = null) { Contracts.CheckValueOrNull(ectx); @@ -115,39 +115,40 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schem ectx.CheckNonEmpty(kind, nameof(kind)); ectx.CheckNonEmpty(valueKind, nameof(valueKind)); - int colTmp; - ColumnInfo info; if (!string.IsNullOrWhiteSpace(name)) { -#pragma warning disable MSML_ContractsNameUsesNameof - if (!ColumnInfo.TryCreateFromName(schema, name, out info)) +#pragma warning disable MSML_ContractsNameUsesNameof // This utility method is meant to reflect the argument name of whatever is calling it, so we take that as a parameter, rather than using nameof directly as in most cases. 
+ var col = schema.GetColumnOrNull(name); + if (!col.HasValue) throw ectx.ExceptUserArg(argName, "Score column is missing"); #pragma warning restore MSML_ContractsNameUsesNameof - return info; + return col.Value; } - var maxSetNum = schema.GetMaxMetadataKind(out colTmp, MetadataUtils.Kinds.ScoreColumnSetId, + var maxSetNum = schema.GetMaxMetadataKind(out int colTmp, MetadataUtils.Kinds.ScoreColumnSetId, (s, c) => IsScoreColumnKind(ectx, s, c, kind)); ReadOnlyMemory tmp = default; - foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum)) + foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, maxSetNum)) { + var col = schema[colIdx]; #if DEBUG - schema[col].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnKind, ref tmp); ectx.Assert(ReadOnlyMemoryUtils.EqualsStr(kind, tmp)); #endif // REVIEW: What should this do about hidden columns? Currently we ignore them. 
- if (schema[col].IsHidden) + if (col.IsHidden) continue; - if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) && - ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance) { - return ColumnInfo.CreateFromIndex(schema, col); + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp); + if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + return col; } } - if (!string.IsNullOrWhiteSpace(defName) && ColumnInfo.TryCreateFromName(schema, defName, out info)) - return info; + if (!string.IsNullOrWhiteSpace(defName) && schema.GetColumnOrNull(defName) is Schema.Column defCol) + return defCol; #pragma warning disable MSML_ContractsNameUsesNameof throw ectx.ExceptUserArg(argName, "Score column is missing"); @@ -155,11 +156,11 @@ public static ColumnInfo GetScoreColumnInfo(IExceptionContext ectx, Schema schem } /// - /// Find the optional auxilliary score column to use. If name is specified, that is used. - /// Otherwise, if colScore is part of a score set, this looks in the score set for a column - /// with the given valueKind. If none is found, it returns null. + /// Find the optional auxilliary score column to use. If is specified, that is used. + /// Otherwise, if is part of a score set, this looks in the score set for a column + /// with the given . If none is found, it returns . /// - public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema schema, string name, string argName, + public static Schema.Column? 
GetOptAuxScoreColumn(IExceptionContext ectx, Schema schema, string name, string argName, int colScore, string valueKind, Func testType) { Contracts.CheckValueOrNull(ectx); @@ -171,14 +172,14 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema if (!string.IsNullOrWhiteSpace(name)) { - ColumnInfo info; #pragma warning disable MSML_ContractsNameUsesNameof - if (!ColumnInfo.TryCreateFromName(schema, name, out info)) + var col = schema.GetColumnOrNull(name); + if (!col.HasValue) throw ectx.ExceptUserArg(argName, "{0} column is missing", valueKind); - if (!testType(info.Type)) + if (!testType(col.Value.Type)) throw ectx.ExceptUserArg(argName, "{0} column has incompatible type", valueKind); #pragma warning restore MSML_ContractsNameUsesNameof - return info; + return col.Value; } // Get the score column set id from colScore. @@ -192,17 +193,18 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, Schema schema[colScore].Metadata.GetValue(MetadataUtils.Kinds.ScoreColumnSetId, ref setId); ReadOnlyMemory tmp = default; - foreach (var col in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId)) + foreach (var colIdx in schema.GetColumnSet(MetadataUtils.Kinds.ScoreColumnSetId, setId)) { // REVIEW: What should this do about hidden columns? Currently we ignore them. 
- if (schema[col].IsHidden) + var col = schema[colIdx]; + if (col.IsHidden) continue; - if (schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref tmp) && - ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp)) + + if (col.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.ScoreValueKind)?.Type == TextType.Instance) { - var res = ColumnInfo.CreateFromIndex(schema, col); - if (testType(res.Type)) - return res; + col.Metadata.GetValue(MetadataUtils.Kinds.ScoreValueKind, ref tmp); + if (ReadOnlyMemoryUtils.EqualsStr(valueKind, tmp) && testType(col.Type)) + return col; } } @@ -226,20 +228,17 @@ private static bool IsScoreColumnKind(IExceptionContext ectx, Schema schema, int } /// - /// If str is non-empty, returns it. Otherwise if info is non-null, returns info.Name. - /// Otherwise, returns def. + /// If is non-empty, returns it. Otherwise if is non-, + /// returns its . Otherwise, returns . /// - public static string GetColName(string str, ColumnInfo info, string def) + public static string GetColName(string str, Schema.Column? info, string def) { Contracts.CheckValueOrNull(str); - Contracts.CheckValueOrNull(info); Contracts.CheckValueOrNull(def); if (!string.IsNullOrEmpty(str)) return str; - if (info != null) - return info.Name; - return def; + return info?.Name ?? def; } public static void CheckWeightType(IExceptionContext ectx, ColumnType type) diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index 84221e8a6a..3c8e3d9c82 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -119,8 +119,8 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) ? 
Enumerable.Empty>() : StratCols.Select(col => RoleMappedSchema.CreatePair(Strat, col)); - if (needName && schema.Name != null) - roles = MetadataUtils.Prepend(roles, RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Name)); + if (needName && schema.Name.HasValue) + roles = MetadataUtils.Prepend(roles, RoleMappedSchema.ColumnRole.Name.Bind(schema.Name.Value.Name)); return roles.Concat(GetInputColumnRolesCore(schema)); } @@ -134,9 +134,9 @@ Dictionary IEvaluator.Evaluate(RoleMappedData data) private protected virtual IEnumerable> GetInputColumnRolesCore(RoleMappedSchema schema) { // Get the score column information. - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(ArgumentsBase.ScoreColumn), ScoreColumnKind); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); // Get the label column information. string label = EvaluateUtils.GetColName(LabelCol, schema.Label, DefaultColumnNames.Label); @@ -239,23 +239,23 @@ private IDataView WrapPerInstance(RoleMappedData perInst) colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex); // Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform. 
- if (perInst.Schema.Name == null) + if (perInst.Schema.Name?.Name is string nameName) { - var args = new GenerateNumberTransform.Arguments(); - args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } }; - args.UseCounter = true; - idv = new GenerateNumberTransform(Host, args, idv); + cols.Add((nameName, "Instance")); colsToKeep.Add("Instance"); } else { - cols.Add((perInst.Schema.Name.Name, "Instance")); + var args = new GenerateNumberTransform.Arguments(); + args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } }; + args.UseCounter = true; + idv = new GenerateNumberTransform(Host, args, idv); colsToKeep.Add("Instance"); } // Maml outputs the weight column if it exists. - if (perInst.Schema.Weight != null) - colsToKeep.Add(perInst.Schema.Weight.Name); + if (perInst.Schema.Weight?.Name is string weightName) + colsToKeep.Add(weightName); // Get the other columns from the evaluator. foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema)) diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index de7be6ec84..8c47d9a790 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -77,11 +77,11 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize < 2 || t.ItemType != NumberType.Float) - throw Host.Except("Score column '{0}' has type {1} but must be a vector of two or more items of type R4", score.Name, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "vector of two or more items of type R4", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), 
"Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.Float && t.KeyCount <= 0) - throw Host.Except("Label column '{0}' has type {1} but must be a float or a known-cardinality key", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "float or a known-cardinality key", t.ToString()); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) @@ -119,10 +119,10 @@ private ReadOnlyMemory[] GetClassNames(RoleMappedSchema schema) private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int numClasses = scoreInfo.Type.VectorSize; - return new MultiClassPerInstanceEvaluator(Host, schema.Schema, scoreInfo, schema.Label.Name); + return new MultiClassPerInstanceEvaluator(Host, schema.Schema, scoreInfo, schema.Label.Value.Name); } public override IEnumerable GetOverallMetricColumns() @@ -390,17 +390,17 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory[] classNames, int s internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Host.Assert(PassNum < 1); - Host.AssertValue(schema.Label); + Host.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); Host.Assert(score.Type.VectorSize == _scoresArr.Length); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter>(score.Index); Host.AssertValue(_labelGetter); Host.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if 
(schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -567,15 +567,15 @@ private static VersionInfo GetVersionInfo() private readonly ReadOnlyMemory[] _classNames; private readonly ColumnType[] _types; - public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, ColumnInfo scoreInfo, string labelCol) - : base(env, schema, Contracts.CheckRef(scoreInfo, nameof(scoreInfo)).Name, labelCol) + public MultiClassPerInstanceEvaluator(IHostEnvironment env, Schema schema, Schema.Column scoreColumn, string labelCol) + : base(env, schema, scoreColumn.Name, labelCol) { CheckInputColumnTypes(schema); - _numClasses = scoreInfo.Type.VectorSize; + _numClasses = scoreColumn.Type.VectorSize; _types = new ColumnType[4]; - if (schema[(int) ScoreIndex].HasSlotNames(_numClasses)) + if (schema[ScoreIndex].HasSlotNames(_numClasses)) { var classNames = default(VBuffer>); schema[(int) ScoreIndex].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref classNames); @@ -981,10 +981,10 @@ public override IEnumerable GetOverallMetricColumns() private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // Output the label column. - yield return schema.Label.Name; + yield return schema.Label.Value.Name; // Return the output columns. yield return MultiClassPerInstanceEvaluator.Assigned; @@ -998,13 +998,14 @@ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst { // If the label column is a key without key values, convert it to I8, just for saving the per-instance // text file, since if there are different key counts the columns cannot be appended. 
- if (!perInst.Schema.TryGetColumnIndex(schema.Label.Name, out int labelCol)) - throw Host.Except("Could not find column '{0}'", schema.Label.Name); + string labelName = schema.Label.Value.Name; + if (!perInst.Schema.TryGetColumnIndex(labelName, out int labelCol)) + throw Host.Except("Could not find column '{0}'", labelName); var labelType = perInst.Schema[labelCol].Type; - if (labelType is KeyType keyType && (!(bool) perInst.Schema[labelCol].HasKeyValues(keyType.KeyCount) || labelType.RawKind != DataKind.U4)) + if (labelType is KeyType keyType && (!(bool)perInst.Schema[labelCol].HasKeyValues(keyType.KeyCount) || labelType.RawKind != DataKind.U4)) { - perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name, - schema.Label.Name, perInst.Schema[labelCol].Type, NumberType.R8, + perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, labelName, + labelName, perInst.Schema[labelCol].Type, NumberType.R8, (in uint src, ref double dst) => dst = src == 0 ? 
double.NaN : src - 1 + (double)keyType.Min); } @@ -1022,7 +1023,7 @@ private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst { var type = perInst.Schema[sortedScoresIndex].Type; if (_numTopClasses < type.VectorSize) - perInst = new SlotsDroppingTransformer(Host, MultiClassPerInstanceEvaluator.SortedScores, min: _numTopClasses).Transform(perInst); + perInst = new SlotsDroppingTransformer(Host, MultiClassPerInstanceEvaluator.SortedScores, min: _numTopClasses).Transform(perInst); } return perInst; } diff --git a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs index cbff742597..0b383e8c61 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiOutputRegressionEvaluator.cs @@ -49,11 +49,10 @@ public MultiOutputRegressionEvaluator(IHostEnvironment env, Arguments args) private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); - var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Host.AssertValue(scoreInfo); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); + var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - return new MultiOutputRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name); + return new MultiOutputRegressionPerInstanceEvaluator(Host, schema.Schema, scoreCol.Name, schema.Label.Value.Name); } private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) @@ -61,11 +60,11 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize == 0 || t.ItemType != NumberType.Float) - 
throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", score.Name, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "known size vector of R4", t.ToString()); + Host.Check(schema.Label.HasValue, "Could not find the label column"); + t = schema.Label.Value.Type; if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) - throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "known size vector of R4 or R8", t.ToString()); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) @@ -302,17 +301,17 @@ public Aggregator(IHostEnvironment env, IRegressionLoss lossFunction, int size, internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); + Contracts.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetVecGetterAs(NumberType.Float, row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetVecGetterAs(NumberType.Float, row, schema.Label.Value.Index); _scoreGetter = row.GetGetter>(score.Index); Contracts.AssertValue(_labelGetter); Contracts.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() @@ -644,11 +643,11 @@ private protected override IEnumerable GetPerInstanceColumnsToSave(RoleM // The multi output regression evaluator outputs the label and score column if requested by the user. 
if (!_supressScoresAndLabels) { - yield return schema.Label.Name; + yield return schema.Label.Value.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.MultiOutputRegression); - yield return scoreInfo.Name; + yield return scoreCol.Name; } // Return the output columns. diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index e7f342d1a1..bfeb0e5ebf 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -41,7 +41,7 @@ public QuantileRegressionEvaluator(IHostEnvironment env, Arguments args) private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); int scoreSize = scoreInfo.Type.VectorSize; var type = schema.Schema[scoreInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type; @@ -50,7 +50,7 @@ private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchem schema.Schema[scoreInfo.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref quantiles); Host.Assert(quantiles.IsDense && quantiles.Length == scoreSize); - return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name, scoreSize, quantiles); + return new QuantileRegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name, scoreSize, quantiles); } private protected override void 
CheckScoreAndLabelTypes(RoleMappedSchema schema) @@ -58,14 +58,11 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; if (t.VectorSize == 0 || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8)) - { - throw Host.Except( - "Score column '{0}' has type '{1}' but must be a known length vector of type R4 or R8", score.Name, t); - } - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "vector of type R4 or R8", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4", t.ToString()); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) @@ -541,13 +538,13 @@ public override IEnumerable GetOverallMetricColumns() private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Must contain a label column"); // The quantile regression evaluator outputs the label and score columns. 
- yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.QuantileRegression); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. yield return RegressionPerInstanceEvaluator.L1; diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 5dc4a6dff4..b9cec5d2d6 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -86,28 +86,27 @@ public RankerEvaluator(IHostEnvironment env, Arguments args) private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { - var t = schema.Label.Type; - if (t != NumberType.Float && !t.IsKey) + var t = schema.Label.Value.Type; + if (t != NumberType.Float && !(t is KeyType)) { - throw Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.LabelColumn), "Label column '{0}' has type '{1}' but must be R4 or a key", - schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.LabelColumn), + "label", schema.Label.Value.Name, "R4 or a key", t.ToString()); } - var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - if (scoreInfo.Type != NumberType.Float) + var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); + if (scoreCol.Type != NumberType.Float) { - throw Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.ScoreColumn), "Score column '{0}' has type '{1}' but must be R4", - scoreInfo.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.ScoreColumn), + "score", scoreCol.Name, "R4", t.ToString()); } } private protected override void 
CheckCustomColumnTypesCore(RoleMappedSchema schema) { - var t = schema.Group.Type; - if (!t.IsKey) + var t = schema.Group.Value.Type; + if (!(t is KeyType)) { - throw Host.ExceptUserArg(nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), - "Group column '{0}' has type '{1}' but must be a key", - schema.Group.Name, t); + throw Host.ExceptSchemaMismatch(nameof(RankerMamlEvaluator.Arguments.GroupIdColumn), + "group", schema.Group.Value.Name, "key", t.ToString()); } } @@ -115,7 +114,7 @@ private protected override void CheckCustomColumnTypesCore(RoleMappedSchema sche private protected override Func GetActiveColsCore(RoleMappedSchema schema) { var pred = base.GetActiveColsCore(schema); - return i => i == schema.Group.Index || pred(i); + return i => i == schema.Group.Value.Index || pred(i); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) @@ -126,12 +125,12 @@ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, internal override IDataTransform GetPerInstanceMetricsCore(RoleMappedData data) { Host.CheckValue(data, nameof(data)); - Host.CheckParam(data.Schema.Label != null, nameof(data), "Schema must contain a label column"); + Host.CheckParam(data.Schema.Label.HasValue, nameof(data), "Schema must contain a label column"); var scoreInfo = data.Schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Host.CheckParam(data.Schema.Group != null, nameof(data), "Schema must contain a group column"); + Host.CheckParam(data.Schema.Group.HasValue, nameof(data), "Schema must contain a group column"); return new RankerPerInstanceTransform(Host, data.Data, - data.Schema.Label.Name, scoreInfo.Name, data.Schema.Group.Name, _truncationLevel, _labelGains); + data.Schema.Label.Value.Name, scoreInfo.Name, data.Schema.Group.Value.Name, _truncationLevel, _labelGains); } public override IEnumerable GetOverallMetricColumns() @@ -443,20 +442,20 @@ public Aggregator(IHostEnvironment env, Double[] 
labelGains, int truncationLevel internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); - Contracts.AssertValue(schema.Group); + Contracts.Assert(schema.Label.HasValue); + Contracts.Assert(schema.Group.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); - _newGroupDel = RowCursorUtils.GetIsNewGroupDelegate(row, schema.Group.Index); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + _newGroupDel = RowCursorUtils.GetIsNewGroupDelegate(row, schema.Group.Value.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); if (UnweightedCounters.GroupSummary) { - ValueGetter groupIdBuilder = RowCursorUtils.GetGetterAsStringBuilder(row, schema.Group.Index); + ValueGetter groupIdBuilder = RowCursorUtils.GetGetterAsStringBuilder(row, schema.Group.Value.Index); _groupSbUpdate = () => groupIdBuilder(ref _groupSb); } else @@ -932,15 +931,15 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics, private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckValue(schema.Label, nameof(schema), "Data must contain a label column"); - Host.CheckValue(schema.Group, nameof(schema), "Data must contain a group column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Data must contain a label column"); + Host.CheckParam(schema.Group.HasValue, nameof(schema), "Data must contain a group column"); // The ranking evaluator outputs the label, group key and score columns. 
- yield return schema.Group.Name; - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Group.Value.Name; + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Ranking); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. yield return RankerPerInstanceTransform.Ndcg; diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs index c4004a494e..e5af2e8bad 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs @@ -56,12 +56,12 @@ private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema) { var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); var t = score.Type; - if (t.IsVector || t.ItemType != NumberType.Float) - throw Host.Except("Score column '{0}' has type '{1}' but must be R4", score, t); - Host.Check(schema.Label != null, "Could not find the label column"); - t = schema.Label.Type; + if (t != NumberType.Float) + throw Host.ExceptSchemaMismatch(nameof(schema), "score", score.Name, "R4", t.ToString()); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); + t = schema.Label.Value.Type; if (t != NumberType.R4) - throw Host.Except("Label column '{0}' has type '{1}' but must be R4", schema.Label.Name, t); + throw Host.ExceptSchemaMismatch(nameof(schema), "label", schema.Label.Value.Name, "R4", t.ToString()); } private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName) @@ -71,11 +71,10 @@ private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, private protected override IRowMapper 
CreatePerInstanceRowMapper(RoleMappedSchema schema) { - Contracts.CheckParam(schema.Label != null, nameof(schema), "Could not find the label column"); + Contracts.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column"); var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - Contracts.AssertValue(scoreInfo); - return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Name); + return new RegressionPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, schema.Label.Value.Name); } public override IEnumerable GetOverallMetricColumns() @@ -353,13 +352,13 @@ public RegressionMamlEvaluator(IHostEnvironment env, Arguments args) private protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) { Host.CheckValue(schema, nameof(schema)); - Host.CheckParam(schema.Label != null, nameof(schema), "Schema must contain a label column"); + Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column"); // The regression evaluator outputs the label and score columns. - yield return schema.Label.Name; - var scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), + yield return schema.Label.Value.Name; + var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Regression); - yield return scoreInfo.Name; + yield return scoreCol.Name; // Return the output columns. 
yield return RegressionPerInstanceEvaluator.L1; diff --git a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs index c10cb87903..9a5634de80 100644 --- a/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs +++ b/src/Microsoft.ML.Data/Evaluators/RegressionEvaluatorBase.cs @@ -197,17 +197,17 @@ private protected RegressionAggregatorBase(IHostEnvironment env, IRegressionLoss internal override void InitializeNextPass(Row row, RoleMappedSchema schema) { Contracts.Assert(PassNum < 1); - Contracts.AssertValue(schema.Label); + Contracts.Assert(schema.Label.HasValue); var score = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score); - _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); _scoreGetter = row.GetGetter(score.Index); Contracts.AssertValue(_labelGetter); Contracts.AssertValue(_scoreGetter); - if (schema.Weight != null) - _weightGetter = row.GetGetter(schema.Weight.Index); + if (schema.Weight.HasValue) + _weightGetter = row.GetGetter(schema.Weight.Value.Index); } public override void ProcessRow() diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index cd4e594223..9282af3a27 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -804,21 +804,21 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa env.CheckValue(ch, nameof(ch)); ch.CheckValue(predictor, nameof(predictor)); ch.CheckValue(data, nameof(data)); - ch.CheckParam(data.Schema.Label != null, nameof(data), "data must have a Label column"); + ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column"); var scored = ScoreUtils.GetScorer(predictor, data, env, null); if (caliTrainer.NeedsTraining) { int labelCol; - if 
(!scored.Schema.TryGetColumnIndex(data.Schema.Label.Name, out labelCol)) + if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol)) throw ch.Except("No label column found"); int scoreCol; if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol)) throw ch.Except("No score column found"); - int weightCol; - if (data.Schema.Weight == null || !scored.Schema.TryGetColumnIndex(data.Schema.Weight.Name, out weightCol)) - weightCol = -1; + int weightCol = -1; + if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx) + weightCol = weightIdx; ch.Info("Training calibrator."); using (var cursor = scored.GetRowCursor(col => col == labelCol || col == scoreCol || col == weightCol)) { diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index ee800ea4a0..46a8f29e6e 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -65,7 +65,7 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun if (trainSchema?.Label == null) return mapper; // We don't even have a label identified in a training schema. - var keyType = trainSchema.Schema[trainSchema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var keyType = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (keyType == null || !CanWrap(mapper, keyType)) return mapper; @@ -109,18 +109,17 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound env.AssertValue(mapper); env.AssertValue(trainSchema); env.Assert(mapper is ISchemaBoundRowMapper); + env.Assert(trainSchema.Label.HasValue); + var labelColumn = trainSchema.Label.Value; // Key values from the training schema label, will map to slot names of the score output. 
- var type = trainSchema.Schema[trainSchema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var type = labelColumn.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; env.AssertValue(type); env.Assert(type.IsVector); // Wrap the fetching of the metadata as a simple getter. - ValueGetter> getter = - (ref VBuffer value) => - { - trainSchema.Schema[trainSchema.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); - }; + ValueGetter> getter = (ref VBuffer value) => + labelColumn.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); return MultiClassClassifierScorer.LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.TrainingLabelValues, CanWrap); } diff --git a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs index ceea843335..9ccde8fab7 100644 --- a/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Scorers/FeatureContributionCalculationTransform.cs @@ -339,6 +339,7 @@ private sealed class RowMapper : ISchemaBoundRowMapper public RoleMappedSchema InputRoleMappedSchema { get; } public Schema InputSchema => InputRoleMappedSchema.Schema; + private Schema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value; public Schema OutputSchema { get; } @@ -350,7 +351,7 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s _env = env; _env.AssertValue(schema); _env.AssertValue(parent); - _env.AssertValue(schema.Feature); + _env.Assert(schema.Feature.HasValue); _parent = parent; InputRoleMappedSchema = schema; var genericMapper = parent.GenericMapper.Bind(_env, schema); @@ -361,16 +362,16 @@ public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema s var builder = new SchemaBuilder(); 
builder.AddColumn(DefaultColumnNames.FeatureContributions, TextType.Instance, null); _outputSchema = builder.GetSchema(); - if (InputSchema[InputRoleMappedSchema.Feature.Index].HasSlotNames(InputRoleMappedSchema.Feature.Type.VectorSize)) - InputSchema[InputRoleMappedSchema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); + if (FeatureColumn.HasSlotNames(FeatureColumn.Type.VectorSize)) + FeatureColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); else - _slotNames = VBufferUtils.CreateEmpty>(InputRoleMappedSchema.Feature.Type.VectorSize); + _slotNames = VBufferUtils.CreateEmpty>(FeatureColumn.Type.VectorSize); } else { _outputSchema = Schema.Create(new FeatureContributionSchema(_env, DefaultColumnNames.FeatureContributions, - new VectorType(NumberType.R4, schema.Feature.Type as VectorType), - InputSchema, InputRoleMappedSchema.Feature.Index)); + new VectorType(NumberType.R4, FeatureColumn.Type as VectorType), + InputSchema, FeatureColumn.Index)); } _outputGenericSchema = _genericRowMapper.OutputSchema; @@ -385,7 +386,7 @@ public Func GetDependencies(Func predicate) for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == FeatureColumn.Index; } return col => false; } @@ -400,8 +401,8 @@ public Row GetRow(Row input, Func active) if (active(totalColumnsCount - 1)) { getters[totalColumnsCount - 1] = _parent.Stringify - ? _parent.GetTextContributionGetter(input, InputRoleMappedSchema.Feature.Index, _slotNames) - : _parent.GetContributionGetter(input, InputRoleMappedSchema.Feature.Index); + ? 
_parent.GetTextContributionGetter(input, FeatureColumn.Index, _slotNames) + : _parent.GetContributionGetter(input, FeatureColumn.Index); } var genericRow = _genericRowMapper.GetRow(input, GetGenericPredicate(active)); @@ -421,7 +422,7 @@ public Func GetGenericPredicate(Func predicate) public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name); } } diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 7de5b0cb5a..e052a55870 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -371,9 +371,9 @@ private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoun // them as slot name metadata. But there are a number of conditions for this to actually // happen, so we test those here. If these are not - if (trainSchema == null || trainSchema.Label == null) + if (trainSchema?.Label == null) return mapper; // We don't even have a label identified in a training schema. - var keyType = trainSchema.Schema[trainSchema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var keyType = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (keyType == null || !CanWrap(mapper, keyType)) return mapper; @@ -404,6 +404,7 @@ private static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType) var outSchema = mapper.OutputSchema; int scoreIdx; + var scoreCol = outSchema.GetColumnOrNull(MetadataUtils.Const.ScoreValueKind.Score); if (!outSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreIdx)) return false; // The mapper doesn't even publish a score column to attach the metadata to. 
if (outSchema[scoreIdx].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type != null) @@ -422,7 +423,7 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound env.Assert(mapper is ISchemaBoundRowMapper); // Key values from the training schema label, will map to slot names of the score output. - var type = trainSchema.Schema[trainSchema.Label.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var type = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; env.AssertValue(type); env.Assert(type.IsVector); @@ -430,7 +431,7 @@ private static ISchemaBoundMapper WrapCore(IHostEnvironment env, ISchemaBound ValueGetter> getter = (ref VBuffer value) => { - trainSchema.Schema[trainSchema.Label.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); + trainSchema.Label.Value.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref value); }; return LabelNameBindableMapper.CreateBound(env, (ISchemaBoundRowMapper)mapper, type as VectorType, getter, MetadataUtils.Kinds.SlotNames, CanWrap); diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index e6142e3b50..c14893f1da 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -122,10 +122,9 @@ ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSc using (var ch = env.Register("SchemaBindableWrapper").Start("Bind")) { ch.CheckValue(schema, nameof(schema)); - if (schema.Feature != null) + if (schema.Feature?.Type is ColumnType type) { // Ensure that the feature column type is compatible with the needed input type. - var type = schema.Feature.Type; var typeIn = ValueMapper != null ? 
ValueMapper.InputType : new VectorType(NumberType.Float); if (type != typeIn) { @@ -199,7 +198,7 @@ public SingleValueRowMapper(RoleMappedSchema schema, SchemaBindablePredictorWrap { Contracts.AssertValue(schema); Contracts.AssertValue(parent); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(outputSchema.Count == 1); _parent = parent; @@ -212,14 +211,14 @@ public Func GetDependencies(Func predicate) for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == InputRoleMappedSchema.Feature.Value.Index; } return col => false; } public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Value.Name); } public Schema InputSchema => InputRoleMappedSchema.Schema; @@ -231,7 +230,7 @@ public Row GetRow(Row input, Func predicate) var getters = new Delegate[1]; if (predicate(0)) - getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Index); + getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Value.Index); return new SimpleRow(OutputSchema, input, getters); } } @@ -290,11 +289,11 @@ private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSch Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is ISingleCanSavePfa); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 1); // Score. var mapper = (ISingleCanSavePfa)ValueMapper; // If the features column was not produced, we must hide the outputs. 
- var featureToken = ctx.TokenOrNullForName(schema.Feature.Name); + var featureToken = ctx.TokenOrNullForName(schema.Feature.Value.Name); if (featureToken == null) ctx.Hide(outputNames); var scoreToken = mapper.SaveAsPfa(ctx, featureToken); @@ -306,15 +305,14 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is ISingleCanSaveOnnx); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) <= 2); // PredictedLabel and/or Score. var mapper = (ISingleCanSaveOnnx)ValueMapper; - if (!ctx.ContainsColumn(schema.Feature.Name)) + string featName = schema.Feature.Value.Name; + if (!ctx.ContainsColumn(featName)) return false; - - Contracts.Assert(ctx.ContainsColumn(schema.Feature.Name)); - - return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name)); + Contracts.Assert(ctx.ContainsColumn(featName)); + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); } private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) @@ -401,11 +399,11 @@ private protected override void SaveAsPfaCore(BoundPfaContext ctx, RoleMappedSch Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Assert(ValueMapper is IDistCanSavePfa); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 2); // Score and prob. var mapper = (IDistCanSavePfa)ValueMapper; // If the features column was not produced, we must hide the outputs. 
- string featureToken = ctx.TokenOrNullForName(schema.Feature.Name); + string featureToken = ctx.TokenOrNullForName(schema.Feature.Value.Name); if (featureToken == null) ctx.Hide(outputNames); @@ -423,15 +421,14 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema var mapper = ValueMapper as ISingleCanSaveOnnx; Contracts.CheckValue(mapper, nameof(mapper)); - Contracts.AssertValue(schema.Feature); + Contracts.Assert(schema.Feature.HasValue); Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted Label, Score and Probablity. - if (!ctx.ContainsColumn(schema.Feature.Name)) + var featName = schema.Feature.Value.Name; + if (!ctx.ContainsColumn(featName)) return false; - - Contracts.Assert(ctx.ContainsColumn(schema.Feature.Name)); - - return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(schema.Feature.Name)); + Contracts.Assert(ctx.ContainsColumn(featName)); + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); } private void CheckValid(out IValueMapperDist distMapper) @@ -481,15 +478,13 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto Contracts.AssertValue(parent); Contracts.Assert(parent._distMapper != null); Contracts.AssertValue(schema); - Contracts.AssertValueOrNull(schema.Feature); _parent = parent; InputRoleMappedSchema = schema; OutputSchema = Schema.Create(new BinaryClassifierSchema()); - if (schema.Feature != null) + if (schema.Feature?.Type is ColumnType typeSrc) { - var typeSrc = InputRoleMappedSchema.Feature.Type; Contracts.Check(typeSrc.IsKnownSizeVector && typeSrc.ItemType == NumberType.Float, "Invalid feature column type"); } @@ -499,8 +494,8 @@ public Func GetDependencies(Func predicate) { for (int i = 0; i < OutputSchema.Count; i++) { - if (predicate(i) && InputRoleMappedSchema.Feature != null) - return col => col == InputRoleMappedSchema.Feature.Index; + if (predicate(i) && InputRoleMappedSchema.Feature?.Index is int idx) + return col => col == idx; 
} return col => false; } @@ -519,7 +514,7 @@ private Delegate[] CreateGetters(Row input, bool[] active) if (active[0] || active[1]) { // Put all captured locals at this scope. - var featureGetter = InputRoleMappedSchema.Feature != null ? input.GetGetter>(InputRoleMappedSchema.Feature.Index) : null; + var featureGetter = InputRoleMappedSchema.Feature?.Index is int idx ? input.GetGetter>(idx) : null; Float prob = 0; Float score = 0; long cachedPosition = -1; diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 77268dd4c9..18e8850492 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; @@ -48,11 +49,11 @@ public static void CheckFeatureFloatVector(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Feature; - if (col == null) + if (!data.Schema.Feature.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a feature column."); - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); - if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float) + var col = data.Schema.Feature.Value; + Contracts.Assert(!col.IsHidden); + if (!(col.Type is VectorType vecType && vecType.Size > 0 && vecType.ItemType == NumberType.Float)) throw Contracts.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type); } @@ -65,11 +66,12 @@ public static void CheckFeatureFloatVector(this RoleMappedData data, out int len // If the above function is generalized, this needs to be as well. 
Contracts.AssertValue(data); - Contracts.Assert(data.Schema.Feature != null); - Contracts.Assert(!data.Schema.Schema[data.Schema.Feature.Index].IsHidden); - Contracts.Assert(data.Schema.Feature.Type.IsKnownSizeVector); - Contracts.Assert(data.Schema.Feature.Type.ItemType == NumberType.Float); - length = data.Schema.Feature.Type.VectorSize; + Contracts.Assert(data.Schema.Feature.HasValue); + var col = data.Schema.Feature.Value; + Contracts.Assert(!col.IsHidden); + Contracts.Assert(col.Type.IsKnownSizeVector); + Contracts.Assert(col.Type.ItemType == NumberType.Float); + length = col.Type.VectorSize; } /// @@ -79,11 +81,11 @@ public static void CheckBinaryLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); - if (!col.Type.IsBool && col.Type != NumberType.R4 && col.Type != NumberType.R8 && col.Type.KeyCount != 2) + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); + if (col.Type != BoolType.Instance && col.Type != NumberType.R4 && col.Type != NumberType.R8 && !(col.Type is KeyType keyType && keyType.Count == 2)) { if (col.Type.IsKey) { @@ -113,9 +115,9 @@ public static void CheckRegressionLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); + var col = data.Schema.Label.Value; Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); if (col.Type != NumberType.R4 && col.Type != NumberType.R8) { @@ -133,13 +135,13 @@ public static void CheckMultiClassLabel(this RoleMappedData data, out int count) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if 
(!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); - if (col.Type.KeyCount > 0) + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); + if (col.Type is KeyType keyType && keyType.Count > 0) { - count = col.Type.KeyCount; + count = keyType.Count; return; } @@ -179,10 +181,10 @@ public static void CheckMultiOutputRegressionLabel(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Label; - if (col == null) + if (!data.Schema.Label.HasValue) throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column."); - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); + var col = data.Schema.Label.Value; + Contracts.Assert(!col.IsHidden); if (!col.Type.IsKnownSizeVector || col.Type.ItemType != NumberType.Float) throw Contracts.ExceptParam(nameof(data), "Training label column '{0}' must be a known-size vector of R4, but has type: {1}.", col.Name, col.Type); } @@ -191,10 +193,10 @@ public static void CheckOptFloatWeight(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Weight; - if (col == null) + if (!data.Schema.Weight.HasValue) return; - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); + var col = data.Schema.Weight.Value; + Contracts.Assert(!col.IsHidden); if (col.Type != NumberType.R4 && col.Type != NumberType.R8) throw Contracts.ExceptParam(nameof(data), "Training weight column '{0}' must be of floating point numeric type, but has type: {1}.", col.Name, col.Type); } @@ -203,11 +205,11 @@ public static void CheckOptGroup(this RoleMappedData data) { Contracts.CheckValue(data, nameof(data)); - var col = data.Schema.Group; - if (col == null) + if (!data.Schema.Group.HasValue) return; - Contracts.Assert(!data.Schema.Schema[col.Index].IsHidden); - if (col.Type.IsKey) + var col = data.Schema.Group.Value; + 
Contracts.Assert(!col.IsHidden); + if (col.Type is KeyType) return; throw Contracts.ExceptParam(nameof(data), "Training group column '{0}' type is invalid: {1}. Must be Key type.", col.Name, col.Type); } @@ -249,11 +251,11 @@ public static RowCursor[] CreateRowCursorSet(this RoleMappedData data, CursOpt opt, int n, Random rand, IEnumerable extraCols = null) => data.Data.GetRowCursorSet(CreatePredicate(data, opt, extraCols), n, rand); - private static void AddOpt(HashSet cols, ColumnInfo info) + private static void AddOpt(HashSet cols, Schema.Column? info) { Contracts.AssertValue(cols); - if (info != null) - cols.Add(info.Index); + if (info.HasValue) + cols.Add(info.Value.Index); } /// @@ -264,9 +266,9 @@ public static ValueGetter> GetFeatureFloatVectorGetter(this Row r Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!"); - Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column"); + Contracts.CheckParam(schema.Feature.HasValue, nameof(schema), "Missing feature column"); - return row.GetGetter>(schema.Feature.Index); + return row.GetGetter>(schema.Feature.Value.Index); } /// @@ -287,9 +289,9 @@ public static ValueGetter GetLabelFloatGetter(this Row row, RoleMappedSch Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!"); - Contracts.CheckParam(schema.Label != null, nameof(schema), "Missing label column"); + Contracts.CheckParam(schema.Label.HasValue, nameof(schema), "Missing label column"); - return RowCursorUtils.GetLabelGetter(row, schema.Label.Index); + return RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index); } /// @@ -310,12 +312,11 @@ public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappe Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, 
nameof(schema)); Contracts.Check(schema.Schema == row.Schema, "schemas don't match!"); - Contracts.CheckValueOrNull(schema.Weight); var col = schema.Weight; - if (col == null) + if (!col.HasValue) return null; - return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Index); + return RowCursorUtils.GetGetterAs(NumberType.Float, row, col.Value.Index); } public static ValueGetter GetOptWeightFloatGetter(this Row row, RoleMappedData data) @@ -332,12 +333,11 @@ public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedSchem Contracts.CheckValue(row, nameof(row)); Contracts.CheckValue(schema, nameof(schema)); Contracts.Check(schema.Schema == row.Schema, "schemas don't match!"); - Contracts.CheckValueOrNull(schema.Group); var col = schema.Group; - if (col == null) + if (!col.HasValue) return null; - return RowCursorUtils.GetGetterAs(NumberType.U8, row, col.Index); + return RowCursorUtils.GetGetterAs(NumberType.U8, row, col.Value.Index); } public static ValueGetter GetOptGroupGetter(this Row row, RoleMappedData data) diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs index 56addbdb52..e4ca0348f6 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs @@ -70,27 +70,6 @@ internal interface IColumnFunction : ICanSaveModel NormalizingTransformer.NormalizerModelParametersBase GetNormalizerModelParams(); } - internal static class NormalizeUtils - { - /// - /// Returns whether the feature column in the schema is indicated to be normalized. If the features column is not - /// specified on the schema, then this will return null. - /// - /// The role-mapped schema to query - /// Returns null if does not have - /// defined, and otherwise returns a Boolean value as returned from - /// on that feature column - /// - public static bool? 
FeaturesAreNormalized(this RoleMappedSchema schema) - { - // REVIEW: The role mapped data has the ability to have multiple columns fill the role of features, which is - // useful in some trainers that are nonetheless parameteric and can therefore benefit from normalization. - Contracts.CheckValue(schema, nameof(schema)); - var featInfo = schema.Feature; - return featInfo == null ? default(bool?) : schema.Schema[featInfo.Index].IsNormalized(); - } - } - /// /// This contains entry-point definitions related to . /// diff --git a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs index 66a6ff165e..5515695bdf 100644 --- a/src/Microsoft.ML.Ensemble/EnsembleUtils.cs +++ b/src/Microsoft.ML.Ensemble/EnsembleUtils.cs @@ -18,17 +18,18 @@ public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, Bit { Contracts.AssertValue(host); Contracts.AssertValue(data); - Contracts.Assert(data.Schema.Feature != null); + Contracts.Assert(data.Schema.Feature.HasValue); Contracts.AssertValue(features); + var featCol = data.Schema.Feature.Value; - var type = data.Schema.Feature.Type; + var type = featCol.Type; Contracts.Assert(features.Length == type.VectorSize); int card = Utils.GetCardinality(features); if (card == type.VectorSize) return data; // REVIEW: This doesn't preserve metadata on the features column. Should it? 
- var name = data.Schema.Feature.Name; + var name = featCol.Name; var view = LambdaColumnMapper.Create( host, "FeatureSelector", data.Data, name, name, type, type, (in VBuffer src, ref VBuffer dst) => SelectFeatures(in src, features, card, ref dst)); diff --git a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs index 97ddf9ef43..2becb019e0 100644 --- a/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs +++ b/src/Microsoft.ML.Ensemble/PipelineEnsemble.cs @@ -161,29 +161,31 @@ internal override Delegate CreateScoreGetter(Row input, Func mapperPr public ValueGetter GetLabelGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); - Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label != null, "Mapper was not trained using a label column"); + Parent.Host.Check(Mappers[i].InputRoleMappedSchema.Label.HasValue, "Mapper was not trained using a label column"); + var labelCol = Mappers[i].InputRoleMappedSchema.Label.Value; // The label should be in the output row of the i'th pipeline - var pipelineRow = BoundPipelines[i].GetRow(input, col => col == Mappers[i].InputRoleMappedSchema.Label.Index); + var pipelineRow = BoundPipelines[i].GetRow(input, col => col == labelCol.Index); disposer = pipelineRow.Dispose; - return RowCursorUtils.GetLabelGetter(pipelineRow, Mappers[i].InputRoleMappedSchema.Label.Index); + return RowCursorUtils.GetLabelGetter(pipelineRow, labelCol.Index); } public ValueGetter GetWeightGetter(Row input, int i, out Action disposer) { Parent.Host.Assert(0 <= i && i < Mappers.Length); - if (Mappers[i].InputRoleMappedSchema.Weight == null) + if (!Mappers[i].InputRoleMappedSchema.Weight.HasValue) { ValueGetter weight = (ref float dst) => dst = 1; disposer = null; return weight; } + var weightCol = Mappers[i].InputRoleMappedSchema.Weight.Value; // The weight should be in the output row of the i'th pipeline if it exists. 
- var inputPredicate = Mappers[i].GetDependencies(col => col == Mappers[i].InputRoleMappedSchema.Weight.Index); + var inputPredicate = Mappers[i].GetDependencies(col => col == weightCol.Index); var pipelineRow = BoundPipelines[i].GetRow(input, inputPredicate); disposer = pipelineRow.Dispose; - return pipelineRow.GetGetter(Mappers[i].InputRoleMappedSchema.Weight.Index); + return pipelineRow.GetGetter(weightCol.Index); } } @@ -590,22 +592,22 @@ protected static int CheckLabelColumn(IHostEnvironment env, PredictorModel[] mod var model = models[0]; var edv = new EmptyDataView(env, model.TransformModel.InputSchema); model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred); - var labelInfo = rmd.Schema.Label; - if (labelInfo == null) + if (!rmd.Schema.Label.HasValue) throw env.Except("Training schema for model 0 does not have a label column"); + var labelCol = rmd.Schema.Label.Value; - var labelType = rmd.Schema.Schema[rmd.Schema.Label.Index].Type; + var labelType = labelCol.Type; if (!labelType.IsKey) return CheckNonKeyLabelColumnCore(env, pred, models, isBinary, labelType); if (isBinary && labelType.KeyCount != 2) throw env.Except("Label is not binary"); var schema = rmd.Schema.Schema; - var mdType = schema[labelInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (mdType == null || !mdType.IsKnownSizeVector) throw env.Except("Label column of type key must have a vector of key values metadata"); - return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, (KeyType)labelType, schema, labelInfo.Index, mdType); + return Utils.MarshalInvoke(CheckKeyLabelColumnCore, mdType.ItemType.RawType, env, models, (KeyType)labelType, schema, labelCol.Index, mdType); } // When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the @@ -653,18 +655,19 @@ 
private static int CheckKeyLabelColumnCore(IHostEnvironment env, PredictorMod var model = models[i]; var edv = new EmptyDataView(env, model.TransformModel.InputSchema); model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred); - var labelInfo = rmd.Schema.Label; - if (labelInfo == null) + var labelInfo = rmd.Schema.Label.HasValue; + if (!rmd.Schema.Label.HasValue) throw env.Except("Training schema for model {0} does not have a label column", i); + var labelCol = rmd.Schema.Label.Value; - var curLabelType = rmd.Schema.Schema[rmd.Schema.Label.Index].Type as KeyType; + var curLabelType = labelCol.Type as KeyType; if (!labelType.Equals(curLabelType)) throw env.Except("Label column of model {0} has different type than model 0", i); - var mdType = rmd.Schema.Schema[labelInfo.Index].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; if (!mdType.Equals(keyValuesType)) throw env.Except("Label column of model {0} has different key value type than model 0", i); - rmd.Schema.Schema[labelInfo.Index].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref curLabelNames); + labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref curLabelNames); if (!AreEqual(in labelNames, in curLabelNames)) throw env.Except("Label of model {0} has different values than model 0", i); } diff --git a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs index c22c6e8c48..c9b4890512 100644 --- a/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/FeatureSelector/RandomFeatureSelector.cs @@ -50,7 +50,7 @@ public Subset SelectFeatures(RoleMappedData data, Random rand) _host.CheckValue(data, nameof(data)); data.CheckFeatureFloatVector(); - var type = data.Schema.Feature.Type; + var type = 
data.Schema.Feature.Value.Type; int len = type.VectorSize; var features = new BitArray(len); for (int j = 0; j < len; j++) diff --git a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs index 9eb51c090a..32fe976551 100644 --- a/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs +++ b/src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseSubModelSelector.cs @@ -106,27 +106,27 @@ public virtual void CalculateMetrics(FeatureSubsetModel t == NumberType.Float); - if (probInfo != null) - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probInfo.Name); + var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, scoredSchema, null, nameof(BinaryClassifierMamlEvaluator.Arguments.ProbabilityColumn), + scoreCol.Index, MetadataUtils.Const.ScoreValueKind.Probability, NumberType.Float.Equals); + if (probCol.HasValue) + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probCol.Value.Name); yield break; case PredictionKind.Regression: - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Name); - scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn), + yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name); + scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(RegressionMamlEvaluator.Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.Regression); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); yield break; case PredictionKind.MultiClassClassification: - yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, 
testSchema.Label.Name); - scoreInfo = EvaluateUtils.GetScoreColumnInfo(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn), + yield return RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, testSchema.Label.Value.Name); + scoreCol = EvaluateUtils.GetScoreColumn(Host, scoredSchema, null, nameof(MultiClassMamlEvaluator.Arguments.ScoreColumn), MetadataUtils.Const.ScoreColumnKind.MultiClassClassification); - yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreInfo.Name); + yield return RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreCol.Name); yield break; default: throw Host.Except("Unrecognized prediction kind '{0}'", PredictionKind); diff --git a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index 9826f4a5e7..24de18e152 100644 --- a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -149,7 +149,7 @@ private static IDataView ApplyConvert(List return viewTrain; } - private static List ConvertFeatures(ColumnInfo[] feats, HashSet featNames, List> concatNames, IChannel ch, + private static List ConvertFeatures(IEnumerable feats, HashSet featNames, List> concatNames, IChannel ch, out List cvt, out int errCount) { Contracts.AssertValue(feats); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index c61b809813..6da92db4b5 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -212,7 +212,7 @@ private void Initialize(IHostEnvironment env) private protected void ConvertData(RoleMappedData trainData) { - MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Index, out CategoricalFeatures); + MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, trainData.Schema.Feature.Value.Index, out CategoricalFeatures); var useTranspose = 
UseTranspose(Args.DiskTranspose, trainData) && (ValidData == null || UseTranspose(Args.DiskTranspose, ValidData)); var instanceConverter = new ExamplesToFastTreeBins(Host, Args.MaxBins, useTranspose, !Args.FeatureFlocks, Args.MinDocumentsInLeafs, GetMaxLabel()); @@ -227,13 +227,13 @@ private protected void ConvertData(RoleMappedData trainData) private bool UseTranspose(bool? useTranspose, RoleMappedData data) { Host.AssertValue(data); - Host.AssertValue(data.Schema.Feature); + Host.Assert(data.Schema.Feature.HasValue); if (useTranspose.HasValue) return useTranspose.Value; ITransposeDataView td = data.Data as ITransposeDataView; - return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Index) != null; + return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Value.Index) != null; } protected void TrainCore(IChannel ch) @@ -949,11 +949,11 @@ private DataConverter(RoleMappedData data, IHost host, Double[][] binUpperBounds Contracts.AssertValue(host, "host"); Host = host; Host.CheckValue(data, nameof(data)); - data.CheckFeatureFloatVector(); + data.CheckFeatureFloatVector(out int featLen); data.CheckOptFloatWeight(); data.CheckOptGroup(); - NumFeatures = data.Schema.Feature.Type.VectorSize; + NumFeatures = featLen; if (binUpperBounds != null) { Host.AssertValue(binUpperBounds); @@ -1320,14 +1320,15 @@ public override Dataset GetDataset() return _dataset; } - private static int AddColumnIfNeeded(ColumnInfo info, List toTranspose) + private static int AddColumnIfNeeded(Schema.Column? info, List toTranspose) { - if (info == null) + if (!info.HasValue) return -1; // It is entirely possible that a single column could have two roles, // and so be added twice, but this case is handled by the transposer. 
- toTranspose.Add(info.Index); - return info.Index; + var idx = info.Value.Index; + toTranspose.Add(idx); + return idx; } private ValueMapper, VBuffer> GetCopier(ColumnType itemType1, ColumnType itemType2) @@ -1360,10 +1361,7 @@ private ValueMapper, VBuffer> GetCopier(ColumnType itemT private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxBins, IParallelTraining parallelTraining) { Host.AssertValue(examples); - Host.AssertValue(examples.Schema.Feature); - Host.AssertValueOrNull(examples.Schema.Label); - Host.AssertValueOrNull(examples.Schema.Group); - Host.AssertValueOrNull(examples.Schema.Weight); + Host.Assert(examples.Schema.Feature.HasValue); if (parallelTraining == null) Host.AssertValue(BinUpperBounds); @@ -1388,8 +1386,8 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB data = new LabelConvertTransform(Host, convArgs, data); } // Convert the group column, if one exists. - if (examples.Schema.Group != null) - data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(examples.Schema.Group.Name, examples.Schema.Group.Name, DataKind.U8)).Transform(data); + if (examples.Schema.Group?.Name is string groupName) + data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(groupName, groupName, DataKind.U8)).Transform(data); // Since we've passed it through a few transforms, reconstitute the mapping on the // newly transformed data. 
@@ -1648,7 +1646,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB else { if (groupIdx >= 0) - ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", examples.Schema.Group.Name); + ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", examples.Schema.Group.Value.Name); const int queryChunkSize = 100; qids = new ulong[(numExamples - 1) / queryChunkSize + 1]; boundaries = new int[qids.Length + 1]; @@ -1860,7 +1858,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to else { if (_data.Schema.Group != null) - ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", _data.Schema.Group.Name); + ch.Warning("This is not ranking problem, Group Id '{0}' column will be ignored", _data.Schema.Group.Value.Name); } using (var cursor = new FloatLabelCursor(_data, curOptions)) { diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 4ac4641da3..533af9c511 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -169,7 +169,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr trainData.CheckBinaryLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index dc6073f324..6fcb983311 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -125,7 +125,7 @@ private protected override FastTreeRankingModelParameters TrainModelCore(TrainCo var maxLabel = GetLabelGains().Length - 1; ConvertData(trainData); TrainCore(ch); - FeatureCount = 
trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; } return new FastTreeRankingModelParameters(Host, TrainedEnsemble, FeatureCount, InnerArgs); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 6130ac669b..9bccf98f9a 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -97,7 +97,7 @@ private protected override FastTreeRegressionModelParameters TrainModelCore(Trai trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index af3f1d52f7..50f75c16d9 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -100,7 +100,7 @@ private protected override FastTreeTweedieModelParameters TrainModelCore(TrainCo trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 89e5521147..a99712785c 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -227,7 +227,7 @@ private protected void TrainBase(TrainContext context) DefineScoreTrackers(); if (HasValidSet) DefinePruningTest(); - InputLength = context.TrainingSet.Schema.Feature.Type.ValueCount; + InputLength = context.TrainingSet.Schema.Feature.Value.Type.ValueCount; TrainCore(ch); } @@ -264,13 +264,11 @@ private void 
ConvertData(RoleMappedData trainData, RoleMappedData validationData private bool UseTranspose(bool? useTranspose, RoleMappedData data) { Host.AssertValue(data); - Host.AssertValue(data.Schema.Feature); + Host.Assert(data.Schema.Feature.HasValue); if (useTranspose.HasValue) return useTranspose.Value; - - ITransposeDataView td = data.Data as ITransposeDataView; - return td != null && td.TransposeSchema.GetSlotType(data.Schema.Feature.Index) != null; + return data.Data is ITransposeDataView td && td.TransposeSchema.GetSlotType(data.Schema.Feature.Value.Index) != null; } private void TrainCore(IChannel ch) @@ -1118,11 +1116,12 @@ public Context(IChannel ch, GamPredictorBase pred, RoleMappedData data, IEvaluat _pred = pred; _data = data; var schema = _data.Schema; - ch.Check(schema.Feature.Type.ValueCount == _pred._inputLength); + var featCol = schema.Feature.Value; + ch.Check(featCol.Type.ValueCount == _pred._inputLength); - int len = schema.Feature.Type.ValueCount; - if (schema.Schema[schema.Feature.Index].HasSlotNames(len)) - schema.Schema[schema.Feature.Index].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _featNames); + int len = featCol.Type.ValueCount; + if (featCol.HasSlotNames(len)) + featCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _featNames); else _featNames = VBufferUtils.CreateEmpty>(len); diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 29fb13f6ee..e4710727cd 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -183,7 +183,7 @@ private protected override IPredictorWithFeatureWeights TrainModelCore(Tr trainData.CheckBinaryLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } diff --git 
a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 8923a06ca6..efd12235f5 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -201,7 +201,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr trainData.CheckRegressionLabel(); trainData.CheckFeatureFloatVector(); trainData.CheckOptFloatWeight(); - FeatureCount = trainData.Schema.Feature.Type.ValueCount; + FeatureCount = trainData.Schema.Feature.Value.Type.ValueCount; ConvertData(trainData); TrainCore(ch); } diff --git a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs index b1a21f9541..ab320319f1 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs @@ -413,8 +413,8 @@ internal sealed class FeaturesToContentMap public FeaturesToContentMap(RoleMappedSchema schema) { Contracts.AssertValue(schema); - var feat = schema.Feature; - Contracts.AssertValue(feat); + Contracts.Assert(schema.Feature.HasValue); + var feat = schema.Feature.Value; Contracts.Assert(feat.Type.ValueCount > 0); var sch = schema.Schema; diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 27951364fc..3e0fc0bb76 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -176,6 +176,7 @@ private void IsNormalized(int iinfo, ref bool dst) public Schema InputSchema => InputRoleMappedSchema.Schema; public Schema OutputSchema { get; } + private Schema.Column FeatureColumn => InputRoleMappedSchema.Feature.Value; public ISchemaBindableMapper Bindable => _owner; @@ -185,7 +186,7 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper Contracts.AssertValue(ectx); ectx.AssertValue(owner); 
ectx.AssertValue(schema); - ectx.AssertValue(schema.Feature); + ectx.Assert(schema.Feature.HasValue); _ectx = ectx; @@ -229,7 +230,7 @@ private Delegate[] CreateGetters(Row input, Func predicate) if (!treeValueActive && !leafIdActive && !pathIdActive) return delegates; - var state = new State(_ectx, input, _owner._ensemble, _owner._totalLeafCount, InputRoleMappedSchema.Feature.Index); + var state = new State(_ectx, input, _owner._ensemble, _owner._totalLeafCount, FeatureColumn.Index); // Get the tree value getter. if (treeValueActive) @@ -391,7 +392,7 @@ private void EnsureCachedPosition() public IEnumerable> GetInputColumnRoles() { - yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature.Name); + yield return RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name); } public Func GetDependencies(Func predicate) @@ -399,7 +400,7 @@ public Func GetDependencies(Func predicate) for (int i = 0; i < OutputSchema.Count; i++) { if (predicate(i)) - return col => col == InputRoleMappedSchema.Feature.Index; + return col => col == FeatureColumn.Index; } return col => false; } @@ -639,6 +640,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV ch.Trace("Creating scorer"); var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args); + Contracts.Assert(data.Schema.Feature.HasValue); // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) @@ -647,11 +649,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV // be non-null. 
var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type"); - if (vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.TrainedModelFile), "Predictor in model file expects {0} features, but data has {1} features", - vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize); } ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); @@ -702,6 +704,7 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments RoleMappedData data = null; args.PredictorModel.PrepareData(env, input, out data, out var predictor2); ch.AssertValue(data); + ch.Assert(data.Schema.Feature.HasValue); ch.Assert(predictor == predictor2); // Make sure that the given predictor has the correct number of input features. @@ -711,11 +714,11 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments // be non-null. 
var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type"); - if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.PredictorModel), "Predictor expects {0} features, but data has {1} features", - vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize); } ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index f8c0a8fdb5..ff758ee868 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -141,16 +141,16 @@ private protected override OlsLinearRegressionModelParameters TrainModelCore(Tra { ch.CheckValue(context, nameof(context)); var examples = context.TrainingSet; - ch.CheckParam(examples.Schema.Feature != null, nameof(examples), "Need a feature column"); - ch.CheckParam(examples.Schema.Label != null, nameof(examples), "Need a labelColumn column"); + ch.CheckParam(examples.Schema.Feature.HasValue, nameof(examples), "Need a feature column"); + ch.CheckParam(examples.Schema.Label.HasValue, nameof(examples), "Need a labelColumn column"); // The labelColumn type must be either Float or a key type based on int (if allowKeyLabels is true). - var typeLab = examples.Schema.Label.Type; + var typeLab = examples.Schema.Label.Value.Type; if (typeLab != NumberType.Float) throw ch.Except("Incompatible labelColumn column type {0}, must be {1}", typeLab, NumberType.Float); // The feature type must be a vector of Float. 
- var typeFeat = examples.Schema.Feature.Type; + var typeFeat = examples.Schema.Feature.Value.Type; if (!typeFeat.IsKnownSizeVector) throw ch.Except("Incompatible feature column type {0}, must be known sized vector of {1}", typeFeat, NumberType.Float); if (typeFeat.ItemType != NumberType.Float) diff --git a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs index b56d6d1fc3..1670e8a897 100644 --- a/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs @@ -124,13 +124,12 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa var roles = examples.Schema.GetColumnRoleNames(); var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); - ch.AssertValue(examplesToFeedTrain.Schema.Label); - ch.AssertValue(examplesToFeedTrain.Schema.Feature); - if (examples.Schema.Weight != null) - ch.AssertValue(examplesToFeedTrain.Schema.Weight); + ch.Assert(examplesToFeedTrain.Schema.Label.HasValue); + ch.Assert(examplesToFeedTrain.Schema.Feature.HasValue); + if (examples.Schema.Weight.HasValue) + ch.Assert(examplesToFeedTrain.Schema.Weight.HasValue); - int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; - ch.Check(numFeatures > 0, "Training set has no features, aborting training."); + ch.Check(examplesToFeedTrain.Schema.Feature.Value.Type is VectorType vecType && vecType.Size > 0, "Training set has no features, aborting training."); return examplesToFeedTrain; } @@ -637,7 +636,7 @@ public void Dispose() private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount) { - int numFeatures = data.Schema.Feature.Type.VectorSize; + int numFeatures = data.Schema.Feature.Value.Type.VectorSize; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); int numThreads = 1; ch.CheckUserArg(numThreads > 
0, nameof(_args.NumberOfThreads), diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index bb544bcd1f..7361fe5bf3 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -142,11 +142,11 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) { Host.AssertValue(ch); base.CheckDataValid(ch, data); - var labelType = data.Schema.Label.Type; - if (!(labelType.IsBool || labelType.IsKey || labelType == NumberType.R4)) + var labelType = data.Schema.Label.Value.Type; + if (!(labelType is BoolType || labelType is KeyType || labelType == NumberType.R4)) { throw ch.ExceptParam(nameof(data), - $"Label column '{data.Schema.Label.Name}' is of type '{labelType}', but must be key, boolean or R4."); + $"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType}', but must be key, boolean or R4."); } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 1435b02de0..72f89455fd 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -113,11 +113,11 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) { Host.AssertValue(ch); base.CheckDataValid(ch, data); - var labelType = data.Schema.Label.Type; + var labelType = data.Schema.Label.Value.Type; if (!(labelType.IsBool || labelType.IsKey || labelType == NumberType.R4)) { throw ch.ExceptParam(nameof(data), - $"Label column '{data.Schema.Label.Name}' is of type '{labelType}', but must be key, boolean or R4."); + $"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType}', but must be key, boolean or R4."); } } @@ -143,7 +143,7 @@ private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData dat if (maxLabel >= _maxNumClass) throw ch.ExceptParam(nameof(data), 
$"max labelColumn cannot exceed {_maxNumClass}"); - if (data.Schema.Label.Type is KeyType keyType) + if (data.Schema.Label.Value.Type is KeyType keyType) { ch.Check(keyType.Contiguous, "labelColumn value should be contiguous"); if (hasNaNLabel) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index a92e593c82..f815ed251a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -123,18 +123,21 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) Host.AssertValue(ch); base.CheckDataValid(ch, data); // Check label types. - var labelType = data.Schema.Label.Type; - if (!(labelType.IsKey || labelType == NumberType.R4)) + var labelCol = data.Schema.Label.Value; + var labelType = labelCol.Type; + if (!(labelType is KeyType || labelType == NumberType.R4)) { throw ch.ExceptParam(nameof(data), - $"Label column '{data.Schema.Label.Name}' is of type '{labelType}', but must be key or R4."); + $"Label column '{labelCol.Name}' is of type '{labelType}', but must be key or R4."); } // Check group types. 
- var groupType = data.Schema.Group.Type; - if (!(groupType == NumberType.U4 || groupType.IsKey)) + ch.CheckParam(data.Schema.Group.HasValue, nameof(data), "Need a group column."); + var groupCol = data.Schema.Group.Value; + var groupType = groupCol.Type; + if (!(groupType == NumberType.U4 || groupType is KeyType)) { throw ch.ExceptParam(nameof(data), - $"Group column '{data.Schema.Group.Name}' is of type '{groupType}', but must be U4 or a Key."); + $"Group column '{groupCol.Name}' is of type '{groupType}', but must be U4 or a Key."); } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 26bac8c8aa..36518d9dca 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -131,11 +131,11 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data) { Host.AssertValue(ch); base.CheckDataValid(ch, data); - var labelType = data.Schema.Label.Type; + var labelType = data.Schema.Label.Value.Type; if (!(labelType.IsBool || labelType.IsKey || labelType == NumberType.R4)) { throw ch.ExceptParam(nameof(data), - $"Label column '{data.Schema.Label.Name}' is of type '{labelType}', but must be key, boolean or R4."); + $"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType}', but must be key, boolean or R4."); } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 35aaeda6af..e9a7de1a93 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -171,7 +171,7 @@ private void DisposeParallelTraining() private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) { data.CheckFeatureFloatVector(); - ch.CheckParam(data.Schema.Label != null, nameof(data), "Need a label column"); + ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Need a label 
column"); } protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false) @@ -289,7 +289,7 @@ private CategoricalMetaData GetCategoricalMetaData(IChannel ch, RoleMappedData t trainData.Schema.Schema.TryGetColumnIndex(DefaultColumnNames.Features, out int featureIndex); MetadataUtils.TryGetCategoricalFeatureIndices(trainData.Schema.Schema, featureIndex, out categoricalFeatures); } - var colType = trainData.Schema.Feature.Type; + var colType = trainData.Schema.Feature.Value.Type; int rawNumCol = colType.VectorSize; FeatureCount = rawNumCol; catMetaData.TotalCats = 0; diff --git a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs index 300adb2493..2a4f5eb782 100644 --- a/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs +++ b/src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs @@ -299,23 +299,20 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, ch.AssertValue(data); ch.AssertValueOrNull(validData); - ColumnInfo matrixColumnIndexColInfo; - ColumnInfo matrixRowIndexColInfo; - ColumnInfo validMatrixColumnIndexColInfo = null; - ColumnInfo validMatrixRowIndexColInfo = null; - - ch.CheckValue(data.Schema.Label, nameof(data), "Input data did not have a unique label"); - RecommenderUtils.CheckAndGetMatrixIndexColumns(data, out matrixColumnIndexColInfo, out matrixRowIndexColInfo, isDecode: false); - if (data.Schema.Label.Type != NumberType.R4 && data.Schema.Label.Type != NumberType.R8) - throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", data.Schema.Label.Name, data.Schema.Label.Type); + ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Input data did not have a unique label"); + RecommenderUtils.CheckAndGetMatrixIndexColumns(data, out var matrixColumnIndexColInfo, out var matrixRowIndexColInfo, isDecode: false); + var labelCol = data.Schema.Label.Value; 
+ if (labelCol.Type != NumberType.R4 && labelCol.Type != NumberType.R8) + throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", labelCol.Name, labelCol.Type); MatrixFactorizationPredictor predictor; if (validData != null) { ch.CheckValue(validData, nameof(validData)); - ch.CheckValue(validData.Schema.Label, nameof(validData), "Input validation data did not have a unique label"); - RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out validMatrixColumnIndexColInfo, out validMatrixRowIndexColInfo, isDecode: false); - if (validData.Schema.Label.Type != NumberType.R4 && validData.Schema.Label.Type != NumberType.R8) - throw ch.Except("Column '{0}' for validation label should be floating point, but is instead {1}", data.Schema.Label.Name, data.Schema.Label.Type); + ch.CheckParam(validData.Schema.Label.HasValue, nameof(validData), "Input validation data did not have a unique label"); + RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out var validMatrixColumnIndexColInfo, out var validMatrixRowIndexColInfo, isDecode: false); + var validLabelCol = validData.Schema.Label.Value; + if (validLabelCol.Type != NumberType.R4 && validLabelCol.Type != NumberType.R8) + throw ch.Except("Column '{0}' for validation label should be floating point, but is instead {1}", validLabelCol.Name, validLabelCol.Type); if (!matrixColumnIndexColInfo.Type.Equals(validMatrixColumnIndexColInfo.Type)) { @@ -335,10 +332,10 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, ch.Assert(colCount > 0); // Checks for equality on the validation set ensure it is correct here. 
- using (var cursor = data.Data.GetRowCursor(c => c == matrixColumnIndexColInfo.Index || c == matrixRowIndexColInfo.Index || c == data.Schema.Label.Index)) + using (var cursor = data.Data.GetRowCursor(c => c == matrixColumnIndexColInfo.Index || c == matrixRowIndexColInfo.Index || c == data.Schema.Label.Value.Index)) { // LibMF works only over single precision floats, but we want to be able to consume either. - var labGetter = RowCursorUtils.GetGetterAs(NumberType.R4, cursor, data.Schema.Label.Index); + var labGetter = RowCursorUtils.GetGetterAs(NumberType.R4, cursor, data.Schema.Label.Value.Index); var matrixColumnIndexGetter = RowCursorUtils.GetGetterAs(NumberType.U4, cursor, matrixColumnIndexColInfo.Index); var matrixRowIndexGetter = RowCursorUtils.GetGetterAs(NumberType.U4, cursor, matrixRowIndexColInfo.Index); @@ -353,10 +350,11 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, } else { + RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out var validMatrixColumnIndexColInfo, out var validMatrixRowIndexColInfo, isDecode: false); using (var validCursor = validData.Data.GetRowCursor( - c => c == validMatrixColumnIndexColInfo.Index || c == validMatrixRowIndexColInfo.Index || c == validData.Schema.Label.Index)) + c => c == validMatrixColumnIndexColInfo.Index || c == validMatrixRowIndexColInfo.Index || c == validData.Schema.Label.Value.Index)) { - ValueGetter validLabelGetter = RowCursorUtils.GetGetterAs(NumberType.R4, validCursor, validData.Schema.Label.Index); + ValueGetter validLabelGetter = RowCursorUtils.GetGetterAs(NumberType.R4, validCursor, validData.Schema.Label.Value.Index); var validMatrixColumnIndexGetter = RowCursorUtils.GetGetterAs(NumberType.U4, validCursor, validMatrixColumnIndexColInfo.Index); var validMatrixRowIndexGetter = RowCursorUtils.GetGetterAs(NumberType.U4, validCursor, validMatrixRowIndexColInfo.Index); diff --git a/src/Microsoft.ML.Recommender/RecommenderUtils.cs 
b/src/Microsoft.ML.Recommender/RecommenderUtils.cs index 607a149cdc..d683bad3ad 100644 --- a/src/Microsoft.ML.Recommender/RecommenderUtils.cs +++ b/src/Microsoft.ML.Recommender/RecommenderUtils.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Threading; +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; @@ -12,14 +12,14 @@ internal static class RecommenderUtils { /// /// Check if the considered data, , contains column roles specified by and . - /// If the column roles, and , uniquely exist in data, their would be assigned + /// If the column roles, and , uniquely exist in data, their would be assigned /// to the two out parameters below. /// /// The considered data being checked - /// The column as role row index in the input data - /// The column as role column index in the input data + /// The schema column as the row in the input data + /// The schema column as the column in the input data /// Whether a non-user error should be thrown as a decode - public static void CheckAndGetMatrixIndexColumns(RoleMappedData data, out ColumnInfo matrixColumnIndexColumn, out ColumnInfo matrixRowIndexColumn, bool isDecode) + public static void CheckAndGetMatrixIndexColumns(RoleMappedData data, out Schema.Column matrixColumnIndexColumn, out Schema.Column matrixRowIndexColumn, bool isDecode) { Contracts.AssertValue(data); CheckRowColumnType(data, MatrixColumnIndexKind, out matrixColumnIndexColumn, isDecode); @@ -38,15 +38,15 @@ private static bool TryMarshalGoodRowColumnType(ColumnType type, out KeyType key } /// - /// Checks whether a column kind in a RoleMappedData is unique, and its type - /// is a U4 key of known cardinality. + /// Checks whether a column kind in a is unique, and its type + /// is a key of known cardinality. 
/// /// The training examples /// The column role to try to extract - /// The extracted column info + /// The extracted schema column /// Whether a non-user error should be thrown as a decode /// The type cast to a key-type - private static KeyType CheckRowColumnType(RoleMappedData data, RoleMappedSchema.ColumnRole role, out ColumnInfo info, bool isDecode) + private static KeyType CheckRowColumnType(RoleMappedData data, RoleMappedSchema.ColumnRole role, out Schema.Column col, bool isDecode) { Contracts.AssertValue(data); Contracts.AssertValue(role.Value); @@ -59,17 +59,17 @@ private static KeyType CheckRowColumnType(RoleMappedData data, RoleMappedSchema. throw Contracts.ExceptDecode(format2, role.Value, kindCount); throw Contracts.Except(format2, role.Value, kindCount); } - info = data.Schema.GetColumns(role)[0]; + col = data.Schema.GetColumns(role)[0]; // REVIEW tfinley: Should we be a bit less restrictive? This doesn't seem like // too terrible of a restriction. const string format = "Column '{0}' with role {1} should be a known cardinality U4 key, but is instead '{2}'"; KeyType keyType; - if (!TryMarshalGoodRowColumnType(info.Type, out keyType)) + if (!TryMarshalGoodRowColumnType(col.Type, out keyType)) { if (isDecode) - throw Contracts.ExceptDecode(format, info.Name, role.Value, info.Type); - throw Contracts.Except(format, info.Name, role.Value, info.Type); + throw Contracts.ExceptDecode(format, col.Name, role.Value, col.Type); + throw Contracts.Except(format, col.Name, role.Value, col.Type); } return keyType; } diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index adcd598019..a99f74d268 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -246,7 +246,7 @@ private static double 
CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no int latentDimAligned, AlignedArray latentSum, int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, VBuffer buffer, ref long badExampleCount) { var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature); - Func pred = c => featureColumns.Select(ci => ci.Index).Contains(c) || c == data.Schema.Label.Index || (data.Schema.Weight != null && c == data.Schema.Weight.Index); + Func pred = c => featureColumns.Select(ci => ci.Index).Contains(c) || c == data.Schema.Label.Value.Index || c == data.Schema.Weight?.Index; var getters = new ValueGetter>[featureColumns.Count]; float label = 0; float weight = 1; @@ -257,8 +257,8 @@ private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool no int count = 0; using (var cursor = data.Data.GetRowCursor(pred)) { - var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Index); ; - var weightGetter = data.Schema.Weight == null ? null : cursor.GetGetter(data.Schema.Weight.Index); + var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index); + var weightGetter = data.Schema.Weight?.Index is int weightIdx ? 
cursor.GetGetter(weightIdx) : null; for (int f = 0; f < featureColumns.Count; f++) getters[f] = cursor.GetGetter>(featureColumns[f].Index); while (cursor.MoveNext()) @@ -300,9 +300,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress for (int f = 0; f < fieldCount; f++) { var col = featureColumns[f]; - if (col == null) - throw ch.ExceptParam(nameof(data), "Empty feature column not allowed"); - Host.Assert(!data.Schema.Schema[col.Index].IsHidden); + Host.Assert(!col.IsHidden); if (!(col.Type is VectorType vectorType) || !vectorType.IsKnownSizeVector || vectorType.ItemType != NumberType.Float) @@ -323,7 +321,12 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress var validFeatureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature); Host.Assert(fieldCount == validFeatureColumns.Count); for (int f = 0; f < fieldCount; f++) - Host.Assert(featureColumns[f] == validFeatureColumns[f]); + { + var featCol = featureColumns[f]; + var validFeatCol = validFeatureColumns[f]; + Host.Assert(featCol.Name == validFeatCol.Name); + Host.Assert(featCol.Type == validFeatCol.Type); + } } bool shuffle = _shuffle; if (shuffle && !data.Data.CanShuffle) @@ -352,7 +355,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress entry.SetProgress(0, iter, _numIterations); entry.SetProgress(1, exampleCount); }); - Func pred = c => fieldColumnIndexes.Contains(c) || c == data.Schema.Label.Index || (data.Schema.Weight != null && c == data.Schema.Weight.Index); + Func pred = c => fieldColumnIndexes.Contains(c) || c == data.Schema.Label.Value.Index || c == data.Schema.Weight?.Index; InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights, out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned); @@ -361,8 +364,8 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress { using (var 
cursor = data.Data.GetRowCursor(pred, rng)) { - var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Index); - var weightGetter = data.Schema.Weight == null ? null : RowCursorUtils.GetGetterAs(NumberType.R4, cursor, data.Schema.Weight.Index); + var labelGetter = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index); + var weightGetter = data.Schema.Weight?.Index is int weightIdx ? RowCursorUtils.GetGetterAs(NumberType.R4, cursor, weightIdx) : null; for (int i = 0; i < fieldCount; i++) featureGetters[i] = cursor.GetGetter>(fieldColumnIndexes[i]); loss = 0; diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index 7e7dddf4d8..fabadcabca 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -66,7 +66,7 @@ internal sealed class FieldAwareFactorizationMachineScalarRowMapper : ISchemaBou public ISchemaBindableMapper Bindable => _pred; - private readonly ColumnInfo[] _columns; + private readonly Schema.Column[] _columns; private readonly List _inputColumnIndexes; private readonly IHostEnvironment _env; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs index 5ae08ea4dd..cbfc68dcea 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs @@ -208,12 +208,12 @@ private protected override void ComputeTrainingStatistics(IChannel ch, FloatLabe ch.Info("AIC: \t{0}", 2 * numParams + deviance); // Show the coefficients statistics table. 
- var featureColIdx = cursorFactory.Data.Schema.Feature.Index; + var featureCol = cursorFactory.Data.Schema.Feature.Value; var schema = cursorFactory.Data.Data.Schema; var featureLength = CurrentWeights.Length - BiasCount; var namesSpans = VBufferUtils.CreateEmpty>(featureLength); - if (schema[featureColIdx].HasSlotNames(featureLength)) - schema[featureColIdx].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref namesSpans); + if (featureCol.HasSlotNames(featureLength)) + featureCol.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref namesSpans); Host.Assert(namesSpans.Length == featureLength); // Inverse mapping of non-zero weight slots. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 0be38f3547..e666ef90a5 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -126,18 +126,15 @@ private protected override void CheckLabel(RoleMappedData data) _prior = new Double[_numClasses]; // Try to get the label key values metedata. 
- var schema = data.Data.Schema; - var labelIdx = data.Schema.Label.Index; - var labelMetadataType = schema[labelIdx].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; - if (labelMetadataType == null || !labelMetadataType.IsKnownSizeVector || !labelMetadataType.ItemType.IsText || - labelMetadataType.VectorSize != _numClasses) + var labelCol = data.Schema.Label.Value; + var labelMetadataType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type; + if (!(labelMetadataType is VectorType vecType && vecType.ItemType == TextType.Instance && vecType.Size == _numClasses)) { _labelNames = null; return; } - VBuffer> labelNames = default; - schema[labelIdx].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); + labelCol.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref labelNames); // If label names is not dense or contain NA or default value, then it follows that // at least one class does not have a valid name for its label. If the label names we diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 9b71f7a7f6..1f2905dc72 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -97,9 +97,9 @@ private protected IDataView MapLabelsCore(ColumnType type, InPredicate equ Host.Assert(type.RawType == typeof(T)); Host.AssertValue(equalsTarget); Host.AssertValue(data); - Host.AssertValue(data.Schema.Label); + Host.Assert(data.Schema.Label.HasValue); - var lab = data.Schema.Label; + var lab = data.Schema.Label.Value; InPredicate isMissing; if (!Args.ImputeMissingLabelsAsNegative && Conversions.Instance.TryGetIsNAPredicate(type, out isMissing)) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs 
b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 66294fd240..7c12a310ac 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -93,16 +93,17 @@ private protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainCon { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; - Host.Check(data.Schema.Label != null, "Missing Label column"); - Host.Check(data.Schema.Label.Type == NumberType.Float || data.Schema.Label.Type is KeyType, + Host.Check(data.Schema.Label.HasValue, "Missing Label column"); + var labelCol = data.Schema.Label.Value; + Host.Check(labelCol.Type == NumberType.Float || labelCol.Type is KeyType, "Invalid type for Label column, only floats and known-size keys are supported"); - Host.Check(data.Schema.Feature != null, "Missing Feature column"); + Host.Check(data.Schema.Feature.HasValue, "Missing Feature column"); int featureCount; data.CheckFeatureFloatVector(out featureCount); int labelCount = 0; - if (data.Schema.Label.Type.IsKey) - labelCount = data.Schema.Label.Type.KeyCount; + if (labelCol.Type is KeyType labelKeyType) + labelCount = labelKeyType.Count; int[] labelHistogram = new int[labelCount]; int[][] featureHistogram = new int[labelCount][]; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index a3bf1fb5a2..c3232c3dd9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -118,7 +118,7 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel { var view = MapLabels(data, cls); - string trainerLabel = data.Schema.Label.Name; + string trainerLabel = data.Schema.Label.Value.Name; // REVIEW: In principle we could support validation sets and the like via the train 
context, but // this is currently unsupported. @@ -144,8 +144,8 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel private IDataView MapLabels(RoleMappedData data, int cls) { - var lab = data.Schema.Label; - Host.Assert(!data.Schema.Schema[lab.Index].IsHidden); + var lab = data.Schema.Label.Value; + Host.Assert(!lab.IsHidden); Host.Assert(lab.Type.KeyCount > 0 || lab.Type == NumberType.R4 || lab.Type == NumberType.R8); if (lab.Type.KeyCount > 0) diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index cc1a2379a6..5aa1fc70e0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -127,7 +127,7 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel ch { // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the // MetaMulticlassTrainer constructor. 
- string trainerLabel = data.Schema.Label.Name; + string trainerLabel = data.Schema.Label.Value.Name; var view = MapLabels(data, cls1, cls2); var transformer = trainer.Fit(view); @@ -144,8 +144,8 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel ch private IDataView MapLabels(RoleMappedData data, int cls1, int cls2) { - var lab = data.Schema.Label; - Host.Assert(!data.Schema.Schema[lab.Index].IsHidden); + var lab = data.Schema.Label.Value; + Host.Assert(!lab.IsHidden); Host.Assert(lab.Type.KeyCount > 0 || lab.Type == NumberType.R4 || lab.Type == NumberType.R8); if (lab.Type.KeyCount > 0) diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index 1c6645a78f..fbb5bb5580 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -120,13 +120,12 @@ private protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, Ro var roles = examples.Schema.GetColumnRoleNames(); var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); - ch.AssertValue(examplesToFeedTrain.Schema.Label); - ch.AssertValue(examplesToFeedTrain.Schema.Feature); - if (examples.Schema.Weight != null) - ch.AssertValue(examplesToFeedTrain.Schema.Weight); + ch.Assert(examplesToFeedTrain.Schema.Label.HasValue); + ch.Assert(examplesToFeedTrain.Schema.Feature.HasValue); + if (examples.Schema.Weight.HasValue) + ch.Assert(examplesToFeedTrain.Schema.Weight.HasValue); - int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; - ch.Check(numFeatures > 0, "Training set has no features, aborting training."); + ch.Check(examplesToFeedTrain.Schema.Feature.Value.Type is VectorType vecType && vecType.Size > 0, "Training set has no features, aborting training."); return examplesToFeedTrain; } @@ -287,7 +286,7 @@ private protected sealed override TModel TrainCore(IChannel ch, RoleMappedData d Contracts.Assert(predictor == 
null, "SDCA based trainers don't support continuous training."); Contracts.Assert(weightSetCount >= 1); - int numFeatures = data.Schema.Feature.Type.VectorSize; + int numFeatures = data.Schema.Feature.Value.Type.VectorSize; long maxTrainingExamples = MaxDualTableSize / weightSetCount; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight | CursOpt.Id); int numThreads; @@ -1761,8 +1760,9 @@ private protected override TScalarPredictor TrainCore(IChannel ch, RoleMappedDat Contracts.AssertValue(data); Contracts.Assert(weightSetCount == 1); Contracts.AssertValueOrNull(predictor); + Contracts.Assert(data.Schema.Feature.HasValue); - int numFeatures = data.Schema.Feature.Type.VectorSize; + int numFeatures = data.Schema.Feature.Value.Type.VectorSize; var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features | CursOpt.Weight); int numThreads; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs index 410acf072c..722e20109b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs @@ -278,19 +278,20 @@ public BinaryPredictionTransformer Fit(IDataView input) private protected override PriorPredictor Train(TrainContext context) { - Contracts.CheckValue(context, nameof(context)); + Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; data.CheckBinaryLabel(); - Contracts.CheckParam(data.Schema.Label != null, nameof(data), "Missing Label column"); - Contracts.CheckParam(data.Schema.Label.Type == NumberType.Float, nameof(data), "Invalid type for Label column"); + Host.CheckParam(data.Schema.Label.HasValue, nameof(data), "Missing Label column"); + var labelCol = data.Schema.Label.Value; + Host.CheckParam(labelCol.Type == NumberType.Float, nameof(data), "Invalid type for Label column"); 
double pos = 0; double neg = 0; - int col = data.Schema.Label.Index; + int col = labelCol.Index; int colWeight = -1; if (data.Schema.Weight?.Type == NumberType.Float) - colWeight = data.Schema.Weight.Index; + colWeight = data.Schema.Weight.Value.Index; using (var cursor = data.Data.GetRowCursor(c => c == col || c == colWeight)) { var getLab = cursor.GetLabelFloatGetter(data); diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index d27ce8791c..e4e00d5507 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -85,13 +85,12 @@ private protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, Ro var roles = examples.Schema.GetColumnRoleNames(); var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles); - ch.AssertValue(examplesToFeedTrain.Schema.Label); - ch.AssertValue(examplesToFeedTrain.Schema.Feature); - if (examples.Schema.Weight != null) - ch.AssertValue(examplesToFeedTrain.Schema.Weight); + ch.Assert(examplesToFeedTrain.Schema.Label.HasValue); + ch.Assert(examplesToFeedTrain.Schema.Feature.HasValue); + if (examples.Schema.Weight.HasValue) + ch.Assert(examplesToFeedTrain.Schema.Weight.HasValue); - int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize; - ch.Check(numFeatures > 0, "Training set has no features, aborting training."); + ch.Check(examplesToFeedTrain.Schema.Feature.Value.Type is VectorType vecType && vecType.Size > 0, "Training set has no features, aborting training."); return examplesToFeedTrain; } diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index b9534e8133..da22cb466e 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ 
b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -1225,12 +1225,14 @@ internal override void Train(FixedSizeQueue data) /// The training time-series. internal override void Train(RoleMappedData data) { - _host.CheckParam(data != null, nameof(data), "The input series for training cannot be null."); - if (data.Schema.Feature.Type != NumberType.Float) - throw _host.ExceptUserArg(nameof(data.Schema.Feature.Name), "The feature column has type '{0}', but must be a float.", data.Schema.Feature.Type); + _host.CheckValue(data, nameof(data)); + _host.CheckParam(data.Schema.Feature.HasValue, nameof(data), "Must have features column."); + var featureCol = data.Schema.Feature.Value; + if (featureCol.Type != NumberType.Float) + throw _host.ExceptSchemaMismatch(nameof(data), "feature", featureCol.Name, "R4", featureCol.Type.ToString()); Single[] dataArray = new Single[_trainSize]; - int col = data.Schema.Feature.Index; + int col = featureCol.Index; int count = 0; using (var cursor = data.Data.GetRowCursor(c => c == col))