Skip to content

Commit 1b0fff9

Browse files
committed
Replace DropColumns,KeepColumns and ChooseColumns with SelectColumns
This adds the SelectColumns Transform and Estimator that is replacing the DropColumns and ChooseColumns Transforms. With this check-in, Drop and Choose are still in the code base but will be removed. In order to support loading older models, SelectColumns supports loading in Drop and Choose transforms. The changes include: - Implementation of the SelectColumnsTransform, SelectColumnsDataTransform and SelectColumnsEstimator - Backward compatibility with Drop and Choose columns by providing functions on SelectColumns that will be called when loading the model. - Entry point apis for calling select from the command line. - Additional tests. These changes are related to dotnet#754.
1 parent bee7f17 commit 1b0fff9

File tree

18 files changed

+1181
-158
lines changed

18 files changed

+1181
-158
lines changed

src/Microsoft.ML.Core/Data/Schema.cs

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
using Microsoft.ML.Runtime.Data;
5+
using Microsoft.ML.Runtime.Internal.Utilities;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
using System.Text;
10+
namespace Microsoft.ML.Runtime.Data
11+
{
12+
#pragma warning disable CS0618 // Type or member is obsolete
13+
public sealed class Schema : ISchema
14+
#pragma warning restore CS0618 // Type or member is obsolete
15+
{
16+
private readonly Column[] _columns;
17+
private readonly Dictionary<string, int> _nameMap;
18+
public int ColumnCount => _columns.Length;
19+
public Column this[string name]
20+
{
21+
get
22+
{
23+
Contracts.CheckValue(name, nameof(name));
24+
if (!_nameMap.TryGetValue(name, out int col))
25+
throw Contracts.ExceptParam(nameof(name), $"Column '{name}' not found");
26+
return _columns[col];
27+
}
28+
}
29+
public Column this[int col]
30+
{
31+
get
32+
{
33+
Contracts.CheckParam(0 <= col && col < _columns.Length, nameof(col));
34+
return _columns[col];
35+
}
36+
}
37+
public bool TryGetColumnIndex(string name, out int col) => _nameMap.TryGetValue(name, out col);
38+
public Column GetColumnOrNull(string name)
39+
{
40+
Contracts.CheckNonEmpty(name, nameof(name));
41+
if (_nameMap.TryGetValue(name, out int col))
42+
return _columns[col];
43+
return null;
44+
}
45+
public sealed class Column
46+
{
47+
public string Name { get; }
48+
public ColumnType Type { get; }
49+
public MetadataRow Metadata { get; }
50+
public Column(string name, ColumnType type, MetadataRow metadata)
51+
{
52+
Contracts.CheckNonEmpty(name, nameof(name));
53+
Contracts.CheckValue(type, nameof(type));
54+
Contracts.CheckValueOrNull(metadata);
55+
Name = name;
56+
Type = type;
57+
Metadata = metadata;
58+
}
59+
}
60+
public sealed class MetadataRow
61+
{
62+
private readonly (Column column, Delegate getter)[] _values;
63+
public Schema Schema { get; }
64+
public MetadataRow(IEnumerable<(Column column, Delegate getter)> values)
65+
{
66+
Contracts.CheckValue(values, nameof(values));
67+
// Check all getters.
68+
foreach (var (column, getter) in values)
69+
{
70+
Contracts.CheckValue(column, nameof(column));
71+
Contracts.CheckValue(getter, nameof(getter));
72+
Utils.MarshalActionInvoke(CheckGetter<int>, column.Type.RawType, getter);
73+
}
74+
_values = values.ToArray();
75+
Schema = new Schema(_values.Select(x => x.column));
76+
}
77+
private void CheckGetter<TValue>(Delegate getter)
78+
{
79+
var typedGetter = getter as ValueGetter<TValue>;
80+
if (typedGetter == null)
81+
throw Contracts.ExceptParam(nameof(getter), $"Getter of type '{typeof(TValue)}' expected, but {getter.GetType()} found");
82+
}
83+
public ValueGetter<TValue> GetGetter<TValue>(int col)
84+
{
85+
Contracts.CheckParam(0 <= col && col < _values.Length, nameof(col));
86+
var typedGetter = _values[col].getter as ValueGetter<TValue>;
87+
return typedGetter;
88+
}
89+
/// <summary>
90+
/// The class that incrementally builds a <see cref="MetadataRow"/>.
91+
/// </summary>
92+
public sealed class Builder
93+
{
94+
private readonly List<(Column column, Delegate getter)> _items;
95+
public Builder()
96+
{
97+
_items = new List<(Column column, Delegate getter)>();
98+
}
99+
public void Add(MetadataRow metadata, Func<string, bool> selector)
100+
{
101+
Contracts.CheckValueOrNull(metadata);
102+
Contracts.CheckValue(selector, nameof(selector));
103+
if (metadata == null)
104+
return;
105+
foreach (var pair in metadata._values)
106+
{
107+
if (selector(pair.column.Name))
108+
_items.Add(pair);
109+
}
110+
}
111+
public void Add<TValue>(Column column, ValueGetter<TValue> getter)
112+
{
113+
Contracts.CheckValue(column, nameof(column));
114+
Contracts.CheckValue(getter, nameof(getter));
115+
Contracts.CheckParam(column.Type.RawType == typeof(TValue), nameof(getter));
116+
_items.Add((column, getter));
117+
}
118+
public void Add(Column column, Delegate getter)
119+
{
120+
Contracts.CheckValue(column, nameof(column));
121+
Utils.MarshalActionInvoke(AddDelegate<int>, column.Type.RawType, column, getter);
122+
}
123+
public void AddSlotNames(int size, ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter)
124+
=> Add(new Column(MetadataUtils.Kinds.SlotNames, new VectorType(TextType.Instance, size), null), getter);
125+
public void AddKeyValues<TValue>(int size, PrimitiveType valueType, ValueGetter<VBuffer<TValue>> getter)
126+
=> Add(new Column(MetadataUtils.Kinds.KeyValues, new VectorType(valueType, size), null), getter);
127+
public MetadataRow GetMetadataRow() => new MetadataRow(_items);
128+
private void AddDelegate<TValue>(Schema.Column column, Delegate getter)
129+
{
130+
Contracts.CheckValue(column, nameof(column));
131+
Contracts.CheckValue(getter, nameof(getter));
132+
var typedGetter = getter as ValueGetter<TValue>;
133+
Contracts.CheckParam(typedGetter != null, nameof(getter));
134+
_items.Add((column, typedGetter));
135+
}
136+
}
137+
public void GetValue<TValue>(int col, ref TValue value) => GetGetter<TValue>(col)(ref value);
138+
public void GetValue<TValue>(string kind, ref TValue value)
139+
{
140+
if (!Schema.TryGetColumnIndex(kind, out int col))
141+
throw MetadataUtils.ExceptGetMetadata();
142+
GetValue(col, ref value);
143+
}
144+
}
145+
public Schema(IEnumerable<Column> columns)
146+
{
147+
Contracts.CheckValue(columns, nameof(columns));
148+
_columns = columns.ToArray();
149+
_nameMap = new Dictionary<string, int>();
150+
for (int i = 0; i < _columns.Length; i++)
151+
_nameMap[_columns[i].Name] = i;
152+
}
153+
public IEnumerable<(int index, Column column)> GetColumns() => _nameMap.Values.Select(idx => (idx, _columns[idx]));
154+
#pragma warning disable CS0618 // Type or member is obsolete
155+
/// <summary>
156+
/// Manufacture an instance of <see cref="Schema"/> out of any <see cref="ISchema"/>.
157+
/// </summary>
158+
public static Schema Create(ISchema inputSchema)
159+
#pragma warning restore CS0618 // Type or member is obsolete
160+
{
161+
Contracts.CheckValue(inputSchema, nameof(inputSchema));
162+
if (inputSchema is Schema s)
163+
return s;
164+
var columns = new Column[inputSchema.ColumnCount];
165+
for (int i = 0; i < columns.Length; i++)
166+
{
167+
var meta = new MetadataRow.Builder();
168+
foreach (var kvp in inputSchema.GetMetadataTypes(i))
169+
{
170+
var getter = Utils.MarshalInvoke(GetMetadataGetterDelegate<int>, kvp.Value.RawType, inputSchema, i, kvp.Key);
171+
meta.Add(new Column(kvp.Key, kvp.Value, null), getter);
172+
}
173+
columns[i] = new Column(inputSchema.GetColumnName(i), inputSchema.GetColumnType(i), meta.GetMetadataRow());
174+
}
175+
return new Schema(columns);
176+
}
177+
#pragma warning disable CS0618 // Type or member is obsolete
178+
private static Delegate GetMetadataGetterDelegate<TValue>(ISchema schema, int col, string kind)
179+
#pragma warning restore CS0618 // Type or member is obsolete
180+
{
181+
// REVIEW: We are facing a choice here: cache 'value' and get rid of 'schema' reference altogether,
182+
// or retain the reference but be more memory efficient. This code should not stick around for too long
183+
// anyway, so let's not sweat too much, and opt for the latter.
184+
ValueGetter<TValue> getter = (ref TValue value) => schema.GetMetadata(kind, col, ref value);
185+
return getter;
186+
}
187+
#region Legacy schema API to be removed
188+
public string GetColumnName(int col) => this[col].Name;
189+
public ColumnType GetColumnType(int col) => this[col].Type;
190+
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
191+
{
192+
var meta = this[col].Metadata;
193+
if (meta == null)
194+
return Enumerable.Empty<KeyValuePair<string, ColumnType>>();
195+
return meta.Schema.GetColumns().Select(c => new KeyValuePair<string, ColumnType>(c.column.Name, c.column.Type));
196+
}
197+
public ColumnType GetMetadataTypeOrNull(string kind, int col)
198+
{
199+
var meta = this[col].Metadata;
200+
if (meta == null)
201+
return null;
202+
if (meta.Schema.TryGetColumnIndex(kind, out int metaCol))
203+
return meta.Schema[metaCol].Type;
204+
return null;
205+
}
206+
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
207+
{
208+
var meta = this[col].Metadata;
209+
if (meta == null)
210+
throw MetadataUtils.ExceptGetMetadata();
211+
if (!meta.Schema.TryGetColumnIndex(kind, out int metaCol))
212+
throw MetadataUtils.ExceptGetMetadata();
213+
meta.GetValue(metaCol, ref value);
214+
}
215+
#endregion
216+
}
217+
}

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

+2-16
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,14 @@ public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env,
2626
}
2727

2828
[TlcModule.EntryPoint(Name = "Transforms.ColumnSelector", Desc = "Selects a set of columns, dropping all others", UserName = "Select Columns")]
29-
public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, DropColumnsTransform.KeepArguments input)
29+
public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, SelectColumnsTransform.Arguments input)
3030
{
3131
Contracts.CheckValue(env, nameof(env));
3232
var host = env.Register("SelectColumns");
3333
host.CheckValue(input, nameof(input));
3434
EntryPointUtils.CheckInputArgs(host, input);
35-
// We can have an empty Columns array, indicating we
36-
// wish to drop all the columns.
3735

38-
var xf = new DropColumnsTransform(env, input, input.Data);
36+
var xf = new SelectColumnsTransform(env, input.KeepHidden, input.Columns).Transform(input.Data);
3937
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
4038
}
4139

@@ -49,17 +47,5 @@ public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, Co
4947
var xf = CopyColumnsTransform.Create(env, input, input.Data);
5048
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
5149
}
52-
53-
[TlcModule.EntryPoint(Name = "Transforms.ColumnDropper", Desc = "Drops columns from the dataset", UserName = DropColumnsTransform.DropUserName, ShortName = DropColumnsTransform.DropShortName)]
54-
public static CommonOutputs.TransformOutput DropColumns(IHostEnvironment env, DropColumnsTransform.Arguments input)
55-
{
56-
Contracts.CheckValue(env, nameof(env));
57-
var host = env.Register("DropColumns");
58-
host.CheckValue(input, nameof(input));
59-
EntryPointUtils.CheckInputArgs(host, input);
60-
61-
var xf = new DropColumnsTransform(env, input, input.Data);
62-
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
63-
}
6450
}
6551
}

src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs

-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717
[assembly: LoadableClass(typeof(ChooseColumnsTransform), typeof(ChooseColumnsTransform.Arguments), typeof(SignatureDataTransform),
1818
"Choose Columns Transform", "ChooseColumnsTransform", "ChooseColumns", "Choose", DocName = "transform/DropKeepChooseTransforms.md")]
1919

20-
[assembly: LoadableClass(typeof(ChooseColumnsTransform), null, typeof(SignatureLoadDataTransform),
21-
"Choose Columns Transform", ChooseColumnsTransform.LoaderSignature, ChooseColumnsTransform.LoaderSignatureOld)]
22-
2320
namespace Microsoft.ML.Runtime.Data
2421
{
2522
public sealed class ChooseColumnsTransform : RowToRowTransformBase

src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs

-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
[assembly: LoadableClass(DropColumnsTransform.KeepColumnsSummary, typeof(DropColumnsTransform), typeof(DropColumnsTransform.KeepArguments), typeof(SignatureDataTransform),
2020
DropColumnsTransform.KeepUserName, "KeepColumns", "KeepColumnsTransform", DropColumnsTransform.KeepShortName, DocName = "transform/DropKeepChooseTransforms.md")]
2121

22-
[assembly: LoadableClass(DropColumnsTransform.DropColumnsSummary, typeof(DropColumnsTransform), null, typeof(SignatureLoadDataTransform),
23-
DropColumnsTransform.DropUserName, DropColumnsTransform.LoaderSignature)]
24-
2522
namespace Microsoft.ML.Runtime.Data
2623
{
2724
/// <summary>

0 commit comments

Comments
 (0)