Skip to content

Commit 007869c

Browse files
author
Tom Finley
committed
API conveniences for the Normalize transform
1 parent 9e0a4ba commit 007869c

File tree

10 files changed

+156
-90
lines changed

10 files changed

+156
-90
lines changed

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
335335
&& type.ItemType.IsText;
336336
}
337337

338+
/// <summary>
339+
/// Returns whether a column has the <see cref="Kinds.IsNormalized"/> metadata set to true.
340+
/// </summary>
341+
/// <param name="schema">The schema to query</param>
342+
/// <param name="col">Which column in the schema to query</param>
343+
/// <returns>True if and only if the column has the <see cref="Kinds.IsNormalized"/> metadata
344+
/// set to the scalar value <see cref="DvBool.True"/></returns>
345+
public static bool IsNormalized(this ISchema schema, int col)
346+
{
347+
Contracts.CheckValue(schema, nameof(schema));
348+
var value = default(DvBool);
349+
return schema.TryGetMetadata(BoolType.Instance, Kinds.IsNormalized, col, ref value) && value.IsTrue;
350+
}
351+
338352
/// <summary>
339353
/// Tries to get the metadata kind of the specified type for a column.
340354
/// </summary>
@@ -347,6 +361,9 @@ public static bool HasKeyNames(this ISchema schema, int col, int keyCount)
347361
/// <returns>True if the metadata of the right type exists, false otherwise</returns>
348362
public static bool TryGetMetadata<T>(this ISchema schema, PrimitiveType type, string kind, int col, ref T value)
349363
{
364+
Contracts.CheckValue(schema, nameof(schema));
365+
Contracts.CheckValue(type, nameof(type));
366+
350367
var metadataType = schema.GetMetadataTypeOrNull(kind, col);
351368
if (!type.Equals(metadataType))
352369
return false;
@@ -363,7 +380,7 @@ public static bool IsHidden(this ISchema schema, int col)
363380
string name = schema.GetColumnName(col);
364381
int top;
365382
bool tmp = schema.TryGetColumnIndex(name, out top);
366-
Contracts.Assert(tmp, "Why did TryGetColumnIndex return false?");
383+
Contracts.Assert(tmp); // This would only be false if the implementation of schema were buggy.
367384
return !tmp || top != col;
368385
}
369386

src/Microsoft.ML.Core/Data/RoleMappedSchema.cs

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
namespace Microsoft.ML.Runtime.Data
1010
{
1111
/// <summary>
12-
/// This contains information about a column in an IDataView. It is essentially a convenience
13-
/// cache containing the name, column index, and column type for the column.
12+
/// This contains information about a column in an <see cref="IDataView"/>. It is essentially a convenience cache
13+
/// containing the name, column index, and column type for the column. The intended usage is for users of <see cref="RoleMappedSchema"/>
14+
/// to get the column index and type associated with the column mapped to a role.
1415
/// </summary>
1516
public sealed class ColumnInfo
1617
{
@@ -71,12 +72,20 @@ public static ColumnInfo CreateFromIndex(ISchema schema, int index)
7172
}
7273

7374
/// <summary>
74-
/// Encapsulates an ISchema plus column role mapping information. It has convenience fields for
75-
/// several common column roles, but can hold an arbitrary set of column infos. The convenience
76-
/// fields are non-null iff there is a unique column with the corresponding role. When there are
77-
/// no such columns or more than one such column, the field is null. The Has, HasUnique, and
78-
/// HasMultiple methods provide some cardinality information.
79-
/// Note that all columns assigned roles are guaranteed to be non-hidden in this schema.
75+
/// Encapsulates an <see cref="ISchema"/> plus column role mapping information. The purpose of role mappings is to
76+
/// provide information on what the intended usage of a column is. That is: while a given data view may have a column named
77+
/// "Features", by itself that is insufficient: the trainer must be fed a role mapping that says that the role
78+
/// mapping for features is filled by that "Features" column. This allows things like columns not named "Features"
79+
/// to actually fill that role (as opposed to insisting on a hard coding, or having every trainer have to be
80+
/// individually configured). Also, by being a one-to-many mapping, it is a way for learners that can consume
81+
/// multiple features columns to consume that information.
82+
///
83+
/// This class has convenience fields for several common column roles (e.g., <see cref="Feature"/>, <see
84+
/// cref="Label"/>), but can hold an arbitrary set of column infos. The convenience fields are non-null iff there is
85+
/// a unique column with the corresponding role. When there are no such columns or more than one such column, the
86+
/// field is null. The <see cref="Has"/>, <see cref="HasUnique"/>, and <see cref="HasMultiple"/> methods provide
87+
/// some cardinality information. Note that all columns assigned roles are guaranteed to be non-hidden in this
88+
/// schema.
8089
/// </summary>
8190
public sealed class RoleMappedSchema
8291
{
@@ -85,18 +94,16 @@ public sealed class RoleMappedSchema
8594
private const string GroupString = "Group";
8695
private const string WeightString = "Weight";
8796
private const string NameString = "Name";
88-
private const string IdString = "Id";
8997
private const string FeatureContributionsString = "FeatureContributions";
9098

9199
public struct ColumnRole
92100
{
93-
public static ColumnRole Feature { get { return new ColumnRole(FeatureString); } }
94-
public static ColumnRole Label { get { return new ColumnRole(LabelString); } }
95-
public static ColumnRole Group { get { return new ColumnRole(GroupString); } }
96-
public static ColumnRole Weight { get { return new ColumnRole(WeightString); } }
97-
public static ColumnRole Name { get { return new ColumnRole(NameString); } }
98-
public static ColumnRole Id { get { return new ColumnRole(IdString); } }
99-
public static ColumnRole FeatureContributions { get { return new ColumnRole(FeatureContributionsString); } }
101+
public static ColumnRole Feature => FeatureString;
102+
public static ColumnRole Label => LabelString;
103+
public static ColumnRole Group => GroupString;
104+
public static ColumnRole Weight => WeightString;
105+
public static ColumnRole Name => NameString;
106+
public static ColumnRole FeatureContributions => FeatureContributionsString;
100107

101108
public readonly string Value;
102109

@@ -152,11 +159,6 @@ public static KeyValuePair<ColumnRole, string> CreatePair(ColumnRole role, strin
152159
/// </summary>
153160
public readonly ColumnInfo Name;
154161

155-
/// <summary>
156-
/// The Id column, when there is exactly one (null otherwise).
157-
/// </summary>
158-
public readonly ColumnInfo Id;
159-
160162
// Maps from role to the associated column infos.
161163
private readonly Dictionary<string, IReadOnlyList<ColumnInfo>> _map;
162164

@@ -194,9 +196,6 @@ private RoleMappedSchema(ISchema schema, Dictionary<string, IReadOnlyList<Column
194196
case NameString:
195197
Name = cols[0];
196198
break;
197-
case IdString:
198-
Id = cols[0];
199-
break;
200199
}
201200
}
202201
}
@@ -224,8 +223,8 @@ private static void Add(Dictionary<string, List<ColumnInfo>> map, ColumnRole rol
224223

225224
private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
226225
{
227-
Contracts.AssertValue(schema, "schema");
228-
Contracts.AssertValue(roles, "roles");
226+
Contracts.AssertValue(schema);
227+
Contracts.AssertValue(roles);
229228

230229
var map = new Dictionary<string, List<ColumnInfo>>();
231230
foreach (var kvp in roles)
@@ -241,8 +240,8 @@ private static Dictionary<string, List<ColumnInfo>> MapFromNames(ISchema schema,
241240

242241
private static Dictionary<string, List<ColumnInfo>> MapFromNamesOpt(ISchema schema, IEnumerable<KeyValuePair<ColumnRole, string>> roles)
243242
{
244-
Contracts.AssertValue(schema, "schema");
245-
Contracts.AssertValue(roles, "roles");
243+
Contracts.AssertValue(schema);
244+
Contracts.AssertValue(roles);
246245

247246
var map = new Dictionary<string, List<ColumnInfo>>();
248247
foreach (var kvp in roles)
@@ -334,6 +333,13 @@ public IEnumerable<KeyValuePair<ColumnRole, string>> GetColumnRoleNames(ColumnRo
334333
}
335334
}
336335

336+
/// <summary>
337+
/// Returns the <see cref="ColumnInfo"/> corresponding to <paramref name="role"/> if there is
338+
/// exactly one such mapping, and otherwise throws an exception.
339+
/// </summary>
340+
/// <param name="role">The role to look up</param>
341+
/// <returns>The info corresponding to that role, assuming there was only one column
342+
/// mapped to that role</returns>
337343
public ColumnInfo GetUniqueColumn(ColumnRole role)
338344
{
339345
var infos = GetColumns(role);
@@ -398,9 +404,9 @@ public static RoleMappedSchema CreateOpt(ISchema schema, IEnumerable<KeyValuePai
398404
}
399405

400406
/// <summary>
401-
/// Encapsulates an IDataView plus a corresponding RoleMappedSchema. Note that the schema of the
402-
/// RoleMappedSchema is guaranteed to be the same schema of the IDataView, that is,
403-
/// Data.Schema == Schema.Schema.
407+
/// Encapsulates an <see cref="IDataView"/> plus a corresponding <see cref="RoleMappedSchema"/>.
408+
/// Note that the <see cref="RoleMappedSchema.Schema"/> of <see cref="Schema"/> is
409+
/// guaranteed to equal the <see cref="ISchematized.Schema"/> of <see cref="Data"/>.
404410
/// </summary>
405411
public sealed class RoleMappedData
406412
{

src/Microsoft.ML.Data/Commands/TrainCommand.cs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -514,11 +514,8 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
514514
{
515515
if (autoNorm != NormalizeOption.Yes)
516516
{
517-
var nn = trainer as ITrainerEx;
518517
DvBool isNormalized = DvBool.False;
519-
if (nn == null || !nn.NeedNormalization ||
520-
(schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized) &&
521-
isNormalized.IsTrue))
518+
if (trainer.NeedNormalization() != true || schema.IsNormalized(featCol))
522519
{
523520
ch.Info("Not adding a normalizer.");
524521
return false;
@@ -530,20 +527,17 @@ public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITra
530527
}
531528
}
532529
ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");
533-
// Quote the feature column name
534-
string quotedFeatureColumnName = featureColumn;
535-
StringBuilder sb = new StringBuilder();
536-
if (CmdQuoter.QuoteValue(quotedFeatureColumnName, sb))
537-
quotedFeatureColumnName = sb.ToString();
538-
var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax", string.Format("col={{ name={0} source={0} }}", quotedFeatureColumnName));
539-
var loader = view as IDataLoader;
540-
if (loader != null)
541-
{
542-
view = CompositeDataLoader.Create(env, loader,
543-
new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
544-
}
530+
// REVIEW: This verbose constructor should be replaced with zeahmed's enhancements once #405 is committed.
531+
IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
532+
=> NormalizeTransform.Create(innerEnv, new NormalizeTransform.MinMaxArguments()
533+
{
534+
Column = new[] { new NormalizeTransform.AffineColumn { Source = featureColumn, Name = featureColumn } }
535+
}, input);
536+
537+
if (view is IDataLoader loader)
538+
view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
545539
else
546-
view = component.CreateInstance(env, view);
540+
view = ApplyNormalizer(env, view);
547541
return true;
548542
}
549543
return false;

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public sealed class Arguments
4141
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
4242
}
4343

44-
internal struct TransformEx
44+
private struct TransformEx
4545
{
4646
public readonly string Tag;
4747
public readonly string ArgsString;
@@ -78,16 +78,14 @@ private static VersionInfo GetVersionInfo()
7878
// The composition of loader plus transforms in order.
7979
private readonly IDataLoader _loader;
8080
private readonly TransformEx[] _transforms;
81-
private readonly IDataView _view;
8281
private readonly ITransposeDataView _tview;
83-
private readonly ITransposeSchema _tschema;
8482
private readonly IHost _host;
8583

8684
/// <summary>
8785
/// Returns the underlying data view of the composite loader.
8886
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
8987
/// </summary>
90-
internal IDataView View { get { return _view; } }
88+
internal IDataView View { get; }
9189

9290
/// <summary>
9391
/// Creates a loader according to the specified <paramref name="args"/>.
@@ -200,7 +198,7 @@ private static IDataLoader ApplyTransformsCore(IHost host, IDataLoader srcLoader
200198
IDataLoader pipeStart;
201199
if (composite != null)
202200
{
203-
srcView = composite._view;
201+
srcView = composite.View;
204202
exes.AddRange(composite._transforms);
205203
pipeStart = composite._loader;
206204
}
@@ -409,9 +407,9 @@ private CompositeDataLoader(IHost host, TransformEx[] transforms)
409407
_host = host;
410408
_host.AssertNonEmpty(transforms);
411409

412-
_view = transforms[transforms.Length - 1].Transform;
413-
_tview = _view as ITransposeDataView;
414-
_tschema = _tview == null ? new TransposerUtils.SimpleTransposeSchema(_view.Schema) : _tview.TransposeSchema;
410+
View = transforms[transforms.Length - 1].Transform;
411+
_tview = View as ITransposeDataView;
412+
TransposeSchema = _tview?.TransposeSchema ?? new TransposerUtils.SimpleTransposeSchema(View.Schema);
415413

416414
var srcLoader = transforms[0].Transform.Source as IDataLoader;
417415

@@ -561,43 +559,34 @@ private static string GenerateTag(int index)
561559

562560
public long? GetRowCount(bool lazy = true)
563561
{
564-
return _view.GetRowCount(lazy);
562+
return View.GetRowCount(lazy);
565563
}
566564

567-
public bool CanShuffle
568-
{
569-
get { return _view.CanShuffle; }
570-
}
565+
public bool CanShuffle => View.CanShuffle;
571566

572-
public ISchema Schema
573-
{
574-
get { return _view.Schema; }
575-
}
567+
public ISchema Schema => View.Schema;
576568

577-
public ITransposeSchema TransposeSchema
578-
{
579-
get { return _tschema; }
580-
}
569+
public ITransposeSchema TransposeSchema { get; }
581570

582571
public IRowCursor GetRowCursor(Func<int, bool> predicate, IRandom rand = null)
583572
{
584573
_host.CheckValue(predicate, nameof(predicate));
585574
_host.CheckValueOrNull(rand);
586-
return _view.GetRowCursor(predicate, rand);
575+
return View.GetRowCursor(predicate, rand);
587576
}
588577

589578
public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator,
590579
Func<int, bool> predicate, int n, IRandom rand = null)
591580
{
592581
_host.CheckValue(predicate, nameof(predicate));
593582
_host.CheckValueOrNull(rand);
594-
return _view.GetRowCursorSet(out consolidator, predicate, n, rand);
583+
return View.GetRowCursorSet(out consolidator, predicate, n, rand);
595584
}
596585

597586
public ISlotCursor GetSlotCursor(int col)
598587
{
599588
_host.CheckParam(0 <= col && col < Schema.ColumnCount, nameof(col));
600-
if (_tschema == null || _tschema.GetSlotType(col) == null)
589+
if (TransposeSchema?.GetSlotType(col) == null)
601590
{
602591
throw _host.ExceptParam(nameof(col), "Bad call to GetSlotCursor on untransposable column '{0}'",
603592
Schema.GetColumnName(col));

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,8 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
232232
{
233233
// All meta-data is passed through in this case, so don't need the slot names type.
234234
echoSrc[i] = true;
235-
DvBool b = DvBool.False;
236235
isNormalized[i] =
237-
info.SrcTypes[0].ItemType.IsNumber &&
238-
Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, info.SrcIndices[0], ref b) &&
239-
b.IsTrue;
236+
info.SrcTypes[0].ItemType.IsNumber && Input.IsNormalized(info.SrcIndices[0]);
240237
types[i] = info.SrcTypes[0];
241238
continue;
242239
}
@@ -247,9 +244,7 @@ private void CacheTypes(out ColumnType[] types, out ColumnType[] typesSlotNames,
247244
{
248245
foreach (var srcCol in info.SrcIndices)
249246
{
250-
DvBool b = DvBool.False;
251-
if (!Input.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, srcCol, ref b) ||
252-
!b.IsTrue)
247+
if (!Input.IsNormalized(srcCol))
253248
{
254249
isNormalized[i] = false;
255250
break;

0 commit comments

Comments
 (0)