Skip to content

Commit 731381c

Browse files
authored
NAReplace estimator (#917)
1 parent d13b415 commit 731381c

File tree

12 files changed

+913
-535
lines changed

12 files changed

+913
-535
lines changed

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

+1
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema)
280280
if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc))
281281
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input);
282282
var type = inputSchema.GetColumnType(colSrc);
283+
_parent.CheckInputColumn(inputSchema, i, colSrc);
283284
infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type);
284285
}
285286
return infos;

src/Microsoft.ML.Legacy/CSharpApi.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -13983,7 +13983,7 @@ public MinMaxNormalizerPipelineStep(Output output)
1398313983

1398413984
namespace Legacy.Transforms
1398513985
{
13986-
public enum NAHandleTransformReplacementKind
13986+
public enum NAHandleTransformReplacementKind : byte
1398713987
{
1398813988
DefaultValue = 0,
1398913989
Mean = 1,
@@ -14444,7 +14444,7 @@ public MissingValuesRowDropperPipelineStep(Output output)
1444414444

1444514445
namespace Legacy.Transforms
1444614446
{
14447-
public enum NAReplaceTransformReplacementKind
14447+
public enum NAReplaceTransformReplacementKind : byte
1444814448
{
1444914449
DefaultValue = 0,
1445014450
Mean = 1,

src/Microsoft.ML.Transforms/KeyToBinaryVectorTransform.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ private static (string input, string output)[] GetColumnPairs(ColumnInfo[] colum
6969
Contracts.CheckValue(columns, nameof(columns));
7070
return columns.Select(x => (x.Input, x.Output)).ToArray();
7171
}
72+
7273
public IReadOnlyCollection<ColumnInfo> Columns => _columns.AsReadOnly();
7374
private readonly ColumnInfo[] _columns;
7475

@@ -209,7 +210,7 @@ private ColInfo[] CreateInfos(ISchema inputSchema)
209210
if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc))
210211
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input);
211212
var type = inputSchema.GetColumnType(colSrc);
212-
213+
_parent.CheckInputColumn(inputSchema, i, colSrc);
213214
infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type);
214215
}
215216
return infos;

src/Microsoft.ML.Transforms/NAHandleTransform.cs

+11-35
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,28 @@ namespace Microsoft.ML.Runtime.Data
2020
/// <include file='doc.xml' path='doc/members/member[@name="NAHandle"]'/>
2121
public static class NAHandleTransform
2222
{
23-
public enum ReplacementKind
23+
public enum ReplacementKind : byte
2424
{
2525
/// <summary>
2626
/// Replace with the default value of the column based on its type. For example, 'zero' for numeric and 'empty' for string/text columns.
2727
/// </summary>
2828
[EnumValueDisplay("Zero/empty")]
29-
DefaultValue,
29+
DefaultValue = 0,
3030

3131
/// <summary>
3232
/// Replace with the mean value of the column. Supports only numeric/time span/DateTime columns.
3333
/// </summary>
34-
Mean,
34+
Mean = 1,
3535

3636
/// <summary>
3737
/// Replace with the minimum value of the column. Supports only numeric/time span/DateTime columns.
3838
/// </summary>
39-
Minimum,
39+
Minimum = 2,
4040

4141
/// <summary>
4242
/// Replace with the maximum value of the column. Supports only numeric/time span/DateTime columns.
4343
/// </summary>
44-
Maximum,
44+
Maximum = 3,
4545

4646
[HideEnumValue]
4747
Def = DefaultValue,
@@ -135,7 +135,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
135135
h.CheckValue(input, nameof(input));
136136
h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));
137137

138-
var replaceCols = new List<NAReplaceTransform.Column>();
138+
var replaceCols = new List<NAReplaceTransform.ColumnInfo>();
139139
var naIndicatorCols = new List<NAIndicatorTransform.Column>();
140140
var naConvCols = new List<ConvertTransform.Column>();
141141
var concatCols = new List<ConcatTransform.TaggedColumn>();
@@ -149,26 +149,16 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
149149
var addInd = column.ConcatIndicator ?? args.Concat;
150150
if (!addInd)
151151
{
152-
replaceCols.Add(
153-
new NAReplaceTransform.Column()
154-
{
155-
Kind = (NAReplaceTransform.ReplacementKind?)column.Kind,
156-
Name = column.Name,
157-
Source = column.Source,
158-
Slot = column.ImputeBySlot
159-
});
152+
replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));
160153
continue;
161154
}
162155

163156
// Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
164157
// so that they can be concatenated.
165-
int inputCol;
166-
if (!input.Schema.TryGetColumnIndex(column.Source, out inputCol))
158+
if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
167159
throw h.Except("Column '{0}' does not exist", column.Source);
168160
var replaceType = input.Schema.GetColumnType(inputCol);
169-
Delegate conv;
170-
bool identity;
171-
if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out conv, out identity))
161+
if (!Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity))
172162
{
173163
throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
174164
BoolType.Instance, replaceType.ItemType);
@@ -186,14 +176,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
186176
naConvCols.Add(new ConvertTransform.Column() { Name = tmpIsMissingColName, Source = tmpIsMissingColName, ResultType = replaceType.ItemType.RawKind });
187177

188178
// Add the NAReplaceTransform column.
189-
replaceCols.Add(
190-
new NAReplaceTransform.Column()
191-
{
192-
Kind = (NAReplaceTransform.ReplacementKind?)column.Kind,
193-
Name = tmpReplacementColName,
194-
Source = column.Source,
195-
Slot = column.ImputeBySlot
196-
});
179+
replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot));
197180

198181
// Add the ConcatTransform column.
199182
if (replaceType.IsVector)
@@ -237,15 +220,8 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
237220
h.AssertValue(output);
238221
output = new ConvertTransform(h, new ConvertTransform.Arguments() { Column = naConvCols.ToArray() }, output);
239222
}
240-
241223
// Create the NAReplace transform.
242-
output = new NAReplaceTransform(h,
243-
new NAReplaceTransform.Arguments()
244-
{
245-
Column = replaceCols.ToArray(),
246-
ReplacementKind = (NAReplaceTransform.ReplacementKind)args.ReplaceWith,
247-
ImputeBySlot = args.ImputeBySlot
248-
}, output ?? input);
224+
output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray());
249225

250226
// Concat the NAReplaceTransform output and the NAIndicatorTransform output.
251227
if (naIndicatorCols.Count > 0)

src/Microsoft.ML.Transforms/NAHandling.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public static CommonOutputs.TransformOutput Indicator(IHostEnvironment env, NAIn
8888
public static CommonOutputs.TransformOutput Replace(IHostEnvironment env, NAReplaceTransform.Arguments input)
8989
{
9090
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NAReplace", input);
91-
var xf = new NAReplaceTransform(h, input, input.Data);
91+
var xf = NAReplaceTransform.Create(h, input, input.Data);
9292
return new CommonOutputs.TransformOutput()
9393
{
9494
Model = new TransformModel(h, xf, input.Data),

0 commit comments

Comments
 (0)