Skip to content

Commit 53c2a15

Browse files
authored
Normalization API helpers (#446)
* API conveniences for the Normalize transform
1 parent 12e6298 commit 53c2a15

File tree

10 files changed

+157
-100
lines changed

10 files changed

+157
-100
lines changed

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,22 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
335335
&& type.ItemType.IsText;
336336
}
337337

338+
/// <summary>
339+
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> metadata set to true.
340+
/// That metadata should be set when the data has undergone transforms that would render it
341+
/// "normalized."
342+
/// </summary>
343+
/// <param name="schema">The schema to query</param>
344+
/// <param name="col">Which column in the schema to query</param>
345+
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> metadata
346+
/// set to the scalar value <see cref="DvBool.True"/></returns>
347+
public static bool IsNormalized(this ISchema schema, int col)
348+
{
349+
Contracts.CheckValue(schema, nameof(schema));
350+
var value = default(DvBool);
351+
return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value.IsTrue;
352+
}
353+
338354
/// <summary>
339355
/// Tries to get the metadata kind of the specified type for a column.
340356
/// </summary>
@@ -347,6 +363,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
347363
/// <returns>True if the metadata of the right type exists, false otherwise</returns>
348364
public static bool TryGetMetadata<T>(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
349365
{
366+
Contracts.CheckValue(schema, nameof(schema));
367+
Contracts.CheckValue(type, nameof(type));
368+
350369
var metadataType = schema.GetMetadataTypeOrNull(kind, col);
351370
if (!type.Equals(metadataType))
352371
return false;
@@ -363,7 +382,7 @@ public static bool IsHidden(this ISchema schema, int col)
363382
string name = schema.GetColumnName(col);
364383
int top;
365384
bool tmp = schema.TryGetColumnIndex(name, out top);
366-
Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?");
385+
Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy.
367386
return !tmp || top != col;
368387
}
369388

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
namespace Microsoft.ML.Runtime.Data
1010
{
1111
/// <summary>
12-
/// This contains information about a column in an IDataView. It is essentially a convenience
13-
/// cache containing the name, column index, and column type for the column.
12+
/// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
13+
/// containing the name, column index, and column type for the column. The intended usage is that users of <see cref="RoleMappedSchema"/>
14+
/// will have a convenient method of getting the index and type without having to separately query it through the <see cref="ISchema"/>,
15+
/// since practically the first thing a consumer of a <see cref="RoleMappedSchema"/> will want to do once they get a mapping is get
16+
/// the type and index of the corresponding column.
1417
/// </summary>
1518
public sealed class ColumnInfo
1619
{
@@ -71,12 +74,20 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
7174
}
7275

7376
/// <summary>
74-
/// Encapsulates an ISchema plus column role mapping information. It has convenience fields for
75-
/// several common column roles, but can hold an arbitrary set of column infos. The convenience
76-
/// fields are non-null iff there is a unique column with the corresponding role. When there are
77-
/// no such columns or more than one such column, the field is null. The Has, HasUnique, and
78-
/// HasMultiple methods provide some cardinality information.
79-
/// Note that all columns assigned roles are guaranteed to be non-hidden in this schema.
77+
/// Encapsulates an <see cref="ISchema"/> plus column role mapping information. The purpose of role mappings is to
78+
/// provide information on what the intended usage of a column is. That is: while a given data view may have a column named
79+
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
80+
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
81+
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
82+
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
83+
/// multiple features columns to consume that information.
84+
///
85+
/// This class has convenience fields for several common column roles (e.g., <see cref="Feature"/>, <see
86+
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is
87+
/// a unique column with the corresponding role. When there are no such columns or more than one such column, the
88+
/// field is null. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/> methods provide
89+
/// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this
90+
/// schema.
8091
/// </summary>
8192
public sealed class RoleMappedSchema
8293
{
@@ -85,18 +96,16 @@ public sealed class RoleMappedSchema
8596
private const string GroupString = "Group";
8697
private const string WeightString = "Weight";
8798
private const string NameString = "Name";
88-
private const string IdString = "Id";
8999
private const string FeatureContributionsString = "FeatureContributions";
90100

91101
public struct ColumnRole
92102
{
93-
public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } }
94-
public static ColumnRole Label { get { return new ColumnRole(LabelString); } }
95-
public static ColumnRole Group { get { return new ColumnRole(GroupString); } }
96-
public static ColumnRole Weight { get { return new ColumnRole(WeightString); } }
97-
public static ColumnRole Name { get { return new ColumnRole(NameString); } }
98-
public static ColumnRole Id { get { return new ColumnRole(IdString); } }
99-
public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } }
103+
public static ColumnRole Feature => FeatureString;
104+
public static ColumnRole Label => LabelString;
105+
public static ColumnRole Group => GroupString;
106+
public static ColumnRole Weight => WeightString;
107+
public static ColumnRole Name => NameString;
108+
public static ColumnRole FeatureContributions => FeatureContributionsString;
100109

101110
public readonly string Value;
102111

@@ -152,11 +161,6 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
152161
/// </summary>
153162
public readonly ColumnInfo Name;
154163

155-
/// <summary>
156-
/// The Id column, when there is exactly one (null otherwise).
157-
/// </summary>
158-
public readonly ColumnInfo Id;
159-
160164
// Maps from role to the associated column infos.
161165
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;
162166

@@ -194,9 +198,6 @@ private RoleMappedSchema(ISchema schema, Dictionary<string, IReadOnlyList<Column
194198
case NameString:
195199
Name = cols[0];
196200
break;
197-
case IdString:
198-
Id = cols[0];
199-
break;
200201
}
201202
}
202203
}
@@ -224,8 +225,8 @@ private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole rol
224225

225226
private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
226227
{
227-
Contracts.AssertValue(schema, "schema");
228-
Contracts.AssertValue(roles, "roles");
228+
Contracts.AssertValue(schema);
229+
Contracts.AssertValue(roles);
229230

230231
var map = new Dictionary<string, List<ColumnInfo>>();
231232
foreach (var kvp in roles)
@@ -241,8 +242,8 @@ private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema,
241242

242243
private static Dictionary<string, List<ColumnInfo>> MapFromNamesOpt(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
243244
{
244-
Contracts.AssertValue(schema, "schema");
245-
Contracts.AssertValue(roles, "roles");
245+
Contracts.AssertValue(schema);
246+
Contracts.AssertValue(roles);
246247

247248
var map = new Dictionary<string, List<ColumnInfo>>();
248249
foreach (var kvp in roles)
@@ -334,6 +335,13 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
334335
}
335336
}
336337

338+
/// <summary>
339+
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
340+
/// exactly one such mapping, and otherwise throws an exception.
341+
/// </summary>
342+
/// <param name="role">The role to look up</param>
343+
/// <returns>The info corresponding to that role, assuming there was only one column
344+
/// mapped to that</returns>
337345
public ColumnInfo GetUniqueColumn(ColumnRole role)
338346
{
339347
var infos = GetColumns(role);
@@ -398,9 +406,9 @@ public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable<KeyValuePai
398406
}
399407

400408
/// <summary>
401-
/// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the
402-
/// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is,
403-
/// Data.Schema == Schema.Schema.
409+
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
410+
/// Note that the <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
411+
/// guaranteed to equal the <see cref="ISchematized.Schema"/> of <see cref="Data"/>.
404412
/// </summary>
405413
public sealed class RoleMappedData
406414
{

src/Microsoft.ML.Data/Commands/TrainCommand.cs

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -468,16 +468,11 @@ private static List<IDataTransform> BacktrackPipe(IDataView dataPipe, out IDataV
468468
Contracts.AssertValue(dataPipe);
469469

470470
var transforms = new List<IDataTransform>();
471-
while (true)
471+
while (dataPipe is IDataTransform xf)
472472
{
473473
// REVIEW: a malicious user could construct a loop in the Source chain, that would
474474
// cause this method to iterate forever (and throw something when the list overflows). There's
475475
// no way to insulate from ALL malicious behavior.
476-
477-
var xf = dataPipe as IDataTransform;
478-
if (xf == null)
479-
break;
480-
481476
transforms.Add(xf);
482477
dataPipe = xf.Source;
483478
Contracts.AssertValue(dataPipe);
@@ -514,11 +509,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
514509
{
515510
if (autoNorm != NormalizeOption.Yes)
516511
{
517-
var nn = trainer as ITrainerEx;
518512
DvBool isNormalized = DvBool.False;
519-
if (nn == null || !nn.NeedNormalization ||
520-
(schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
521-
isNormalized.IsTrue))
513+
if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
522514
{
523515
ch.Info("Not adding a normalizer.");
524516
return false;
@@ -530,20 +522,13 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
530522
}
531523
}
532524
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
533-
// Quote the feature column name
534-
string quotedFeatureColumnName = featureColumn;
535-
StringBuilder sb = new StringBuilder();
536-
if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
537-
quotedFeatureColumnName = sb.ToString();
538-
var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
539-
var loader = view as IDataLoader;
540-
if (loader != null)
541-
{
542-
view = CompositeDataLoader.Create(env, loader,
543-
new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
544-
}
525+
IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
526+
=> NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);
527+
528+
if (view is IDataLoader loader)
529+
view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
545530
else
546-
view = component.CreateInstance(env, view);
531+
view = ApplyNormalizer(env, view);
547532
return true;
548533
}
549534
return false;

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public sealed class Arguments
4141
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
4242
}
4343

44-
internal struct TransformEx
44+
private struct TransformEx
4545
{
4646
public readonly string Tag;
4747
public readonly string ArgsString;
@@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo()
7878
// The composition of loader plus transforms in order.
7979
private readonly IDataLoader _loader;
8080
private readonly TransformEx[] _transforms;
81-
private readonly IDataView _view;
8281
private readonly ITransposeDataView _tview;
83-
private readonly ITransposeSchema _tschema;
8482
private readonly IHost _host;
8583

8684
/// <summary>
8785
/// Returns the underlying data view of the composite loader.
8886
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
8987
/// </summary>
90-
internal IDataView View { get { return _view; } }
88+
internal IDataView View { get; }
9189

9290
/// <summary>
9391
/// Creates a loader according to the specified <paramref name="args"/>.
@@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader
200198
IDataLoader pipeStart;
201199
if (composite != null)
202200
{
203-
srcView = composite._view;
201+
srcView = composite.View;
204202
exes.AddRange(composite._transforms);
205203
pipeStart = composite._loader;
206204
}
@@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
409407
_host = host;
410408
_host.AssertNonEmpty(transforms);
411409

412-
_view = transforms[transforms.Length - 1].Transform;
413-
_tview = _view as ITransposeDataView;
414-
_tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema;
410+
View = transforms[transforms.Length - 1].Transform;
411+
_tview = View as ITransposeDataView;
412+
TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
415413

416414
var srcLoader = transforms[0].Transform.Source as IDataLoader;
417415

@@ -561,43 +559,34 @@ private static string GenerateTag(int index)
561559

562560
public long? GetRowCount(bool lazy = true)
563561
{
564-
return _view.GetRowCount(lazy);
562+
return View.GetRowCount(lazy);
565563
}
566564

567-
public bool CanShuffle
568-
{
569-
get { return _view.CanShuffle; }
570-
}
565+
public bool CanShuffle => View.CanShuffle;
571566

572-
public ISchema Schema
573-
{
574-
get { return _view.Schema; }
575-
}
567+
public ISchema Schema => View.Schema;
576568

577-
public ITransposeSchema TransposeSchema
578-
{
579-
get { return _tschema; }
580-
}
569+
public ITransposeSchema TransposeSchema { get; }
581570

582571
public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
583572
{
584573
_host.CheckValue(predicate, nameof(predicate));
585574
_host.CheckValueOrNull(rand);
586-
return _view.GetRowCursor(predicate, rand);
575+
return View.GetRowCursor(predicate, rand);
587576
}
588577

589578
public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
590579
Func<int, bool> predicate, int n, IRandom rand = null)
591580
{
592581
_host.CheckValue(predicate, nameof(predicate));
593582
_host.CheckValueOrNull(rand);
594-
return _view.GetRowCursorSet(out consolidator, predicate, n, rand);
583+
return View.GetRowCursorSet(out consolidator, predicate, n, rand);
595584
}
596585

597586
public ISlotCursor GetSlotCursor(int col)
598587
{
599588
_host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
600-
if (_tschema == null || _tschema.GetSlotType(col) == null)
589+
if (TransposeSchema?.GetSlotType(col) == null)
601590
{
602591
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
603592
Schema.GetColumnName(col));

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
245245
{
246246
// All meta-data is passed through in this case, so don't need the slot names type.
247247
echoSrc[i] = true;
248-
DvBool b = DvBool.False;
249248
isNormalized[i] =
250-
info.SrcTypes[0].ItemType.IsNumber &&
251-
Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) &&
252-
b.IsTrue;
249+
info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]);
253250
types[i] = info.SrcTypes[0];
254251
continue;
255252
}
@@ -260,9 +257,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
260257
{
261258
foreach (var srcCol in info.SrcIndices)
262259
{
263-
DvBool b = DvBool.False;
264-
if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) ||
265-
!b.IsTrue)
260+
if (!Input.IsNormalized(srcCol))
266261
{
267262
isNormalized[i] = false;
268263
break;

0 commit comments

Comments
 (0)