Skip to content

Commit 9c1e758

Browse files
committed
finished debugging tests
1 parent 593f7d5 commit 9c1e758

File tree

4 files changed

+98
-32
lines changed

4 files changed

+98
-32
lines changed

src/Microsoft.ML.Transforms/NAIndicatorTransform.cs

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ internal static string TestType(ColumnType type)
8989
type, LoadName);
9090
}
9191

92-
// TODO: Why do we even need an object for this? maybe because we want to deal with array of these things
9392
public class ColumnInfo
9493
{
9594
public readonly string Input;
@@ -156,7 +155,7 @@ protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCo
156155
/// </summary>
157156
/// <param name="env">Host Environment.</param>
158157
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
159-
/// <param name="columns"></param> TODO
158+
/// <param name="columns">Specifies the names of the input columns for the transform and the resulting output column names.</param>
160159
public NAIndicatorTransform(IHostEnvironment env, IDataView input, params ColumnInfo[] columns)
161160
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(NAIndicatorTransform)), GetColumnPairs(columns))
162161
{
@@ -178,7 +177,9 @@ private NAIndicatorTransform(IHost host, ModelLoadContext ctx)
178177
_inputTypes = null;
179178
}
180179

181-
// Factory method for SignatureDataTransform.
180+
/// <summary>
181+
/// Factory method for SignatureDataTransform.
182+
/// </summary>
182183
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
183184
{
184185
Contracts.CheckValue(env, nameof(env));
@@ -196,6 +197,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
196197
return new NAIndicatorTransform(env, input, cols).MakeDataTransform(input);
197198
}
198199

200+
/// <summary>
201+
/// Factory method for SignatureDataTransform.
202+
/// </summary>
199203
public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext ctx)
200204
{
201205
Contracts.CheckValue(env, nameof(env));
@@ -207,14 +211,21 @@ public static NAIndicatorTransform Create(IHostEnvironment env, ModelLoadContext
207211
return new NAIndicatorTransform(host, ctx);
208212
}
209213

214+
/// <summary>
215+
/// Factory method for SignatureDataTransform.
216+
/// </summary>
210217
public static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns)
211218
=> new NAIndicatorTransform(env, input, columns).MakeDataTransform(input);
212219

213-
// Factory method for SignatureLoadDataTransform.
220+
/// <summary>
221+
/// Factory method for SignatureDataTransform.
222+
/// </summary>
214223
public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
215224
=> Create(env, ctx).MakeDataTransform(input);
216225

217-
// Factory method for SignatureLoadRowMapper.
226+
/// <summary>
227+
/// Factory method for SignatureDataTransform.
228+
/// </summary>
218229
public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
219230
=> Create(env, ctx).MakeRowMapper(inputSchema);
220231

@@ -254,7 +265,6 @@ private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action di
254265
Host.AssertValue(input);
255266
Host.Assert(0 <= iinfo && iinfo < ColumnPairs.Length);
256267
disposer = null;
257-
// TODO: is there a better way then checking every time in the getgetter?
258268
if (_outputTypes == null || _inputTypes == null)
259269
(_inputTypes, _outputTypes) = GetTypes(input.Schema);
260270

@@ -269,7 +279,7 @@ private Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action di
269279
private ValueGetter<bool> ComposeGetterOne(IRow input, int iinfo)
270280
{
271281
Func<IRow, int, ValueGetter<bool>> func = ComposeGetterOne<int>;
272-
return Utils.MarshalInvoke(func, _inputTypes[iinfo].RawType, input, iinfo);
282+
return Utils.MarshalInvoke(func, _outputTypes[iinfo].RawType, input, iinfo);
273283
}
274284

275285
private ValueGetter<bool> ComposeGetterOne<T>(IRow input, int iinfo)
@@ -294,7 +304,7 @@ private ValueGetter<bool> ComposeGetterOne<T>(IRow input, int iinfo)
294304
private ValueGetter<VBuffer<bool>> ComposeGetterVec(IRow input, int iinfo)
295305
{
296306
Func<IRow, int, ValueGetter<VBuffer<bool>>> func = ComposeGetterVec<int>;
297-
return Utils.MarshalInvoke(func, _inputTypes[iinfo].ItemType.RawType, input, iinfo);
307+
return Utils.MarshalInvoke(func, _outputTypes[iinfo].ItemType.RawType, input, iinfo);
298308
}
299309

300310
private ValueGetter<VBuffer<bool>> ComposeGetterVec<T>(IRow input, int iinfo)
@@ -506,8 +516,6 @@ public Mapper(NAIndicatorTransform parent, ISchema inputSchema)
506516
_types[i] = BoolType.Instance;
507517
else
508518
_types[i] = new VectorType(BoolType.Instance, type.AsVector);
509-
510-
_types[i] = type;
511519
_isNAs[i] = _parent.GetIsNADelegate(type);
512520
}
513521
}
@@ -626,7 +634,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
626634
string reason = NAIndicatorTransform.TestType(col.ItemType);
627635
if (reason != null)
628636
throw _host.ExceptParam(nameof(inputSchema), reason);
629-
var type = !col.ItemType.IsVector ? col.ItemType : new VectorType(col.ItemType.ItemType.AsPrimitive, col.ItemType.AsVector);
637+
ColumnType type = !col.ItemType.IsVector ? (ColumnType) BoolType.Instance : new VectorType(BoolType.Instance, col.ItemType.AsVector);
630638
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, default);
631639
}
632640
return new SchemaShape(result.Values);
@@ -645,7 +653,7 @@ private interface IColInput
645653
PipelineColumn Input { get; }
646654
}
647655

648-
private sealed class OutScalar<TValue> : Scalar<TValue>, IColInput
656+
private sealed class OutScalar<TValue> : Scalar<bool>, IColInput
649657
{
650658
public PipelineColumn Input { get; }
651659

@@ -656,7 +664,7 @@ public OutScalar(Scalar<TValue> input)
656664
}
657665
}
658666

659-
private sealed class OutVectorColumn<TValue> : Vector<TValue>, IColInput
667+
private sealed class OutVectorColumn<TValue> : Vector<bool>, IColInput
660668
{
661669
public PipelineColumn Input { get; }
662670

@@ -667,7 +675,7 @@ public OutVectorColumn(Vector<TValue> input)
667675
}
668676
}
669677

670-
private sealed class OutVarVectorColumn<TValue> : VarVector<TValue>, IColInput
678+
private sealed class OutVarVectorColumn<TValue> : VarVector<bool>, IColInput
671679
{
672680
public PipelineColumn Input { get; }
673681

@@ -700,55 +708,55 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
700708
}
701709
}
702710

703-
public static Scalar<float> IsMissingValue(this Scalar<float> input)
711+
public static Scalar<bool> IsMissingValue(this Scalar<float> input)
704712
{
705713
Contracts.CheckValue(input, nameof(input));
706714
return new OutScalar<float>(input);
707715
}
708716

709-
public static Scalar<double> IsMissingValue(this Scalar<double> input)
717+
public static Scalar<bool> IsMissingValue(this Scalar<double> input)
710718
{
711719
Contracts.CheckValue(input, nameof(input));
712720
return new OutScalar<double>(input);
713721
}
714722

715-
public static Scalar<string> IsMissingValue(this Scalar<string> input)
723+
public static Scalar<bool> IsMissingValue(this Scalar<string> input)
716724
{
717725
Contracts.CheckValue(input, nameof(input));
718726
return new OutScalar<string>(input);
719727
}
720728

721-
public static Vector<float> IsMissingValue(this Vector<float> input)
729+
public static Vector<bool> IsMissingValue(this Vector<float> input)
722730
{
723731
Contracts.CheckValue(input, nameof(input));
724732
return new OutVectorColumn<float>(input);
725733
}
726734

727-
public static Vector<double> IsMissingValue(this Vector<double> input)
735+
public static Vector<bool> IsMissingValue(this Vector<double> input)
728736
{
729737
Contracts.CheckValue(input, nameof(input));
730738
return new OutVectorColumn<double>(input);
731739
}
732740

733-
public static Vector<string> IsMissingValue(this Vector<string> input)
741+
public static Vector<bool> IsMissingValue(this Vector<string> input)
734742
{
735743
Contracts.CheckValue(input, nameof(input));
736744
return new OutVectorColumn<string>(input);
737745
}
738746

739-
public static VarVector<float> IsMissingValue(this VarVector<float> input)
747+
public static VarVector<bool> IsMissingValue(this VarVector<float> input)
740748
{
741749
Contracts.CheckValue(input, nameof(input));
742750
return new OutVarVectorColumn<float>(input);
743751
}
744752

745-
public static VarVector<double> IsMissingValue(this VarVector<double> input)
753+
public static VarVector<bool> IsMissingValue(this VarVector<double> input)
746754
{
747755
Contracts.CheckValue(input, nameof(input));
748756
return new OutVarVectorColumn<double>(input);
749757
}
750758

751-
public static VarVector<string> IsMissingValue<TValue>(this VarVector<string> input)
759+
public static VarVector<bool> IsMissingValue(this VarVector<string> input)
752760
{
753761
Contracts.CheckValue(input, nameof(input));
754762
return new OutVarVectorColumn<string>(input);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=ScalarFloat:R4:0
5+
#@ col=ScalarDouble:R8:1
6+
#@ col=VectorFloat:R4:2-5
7+
#@ col=VectorDoulbe:R8:6-9
8+
#@ col=A:BL:10
9+
#@ col=B:BL:11
10+
#@ col=C:BL:12-15
11+
#@ col=D:BL:16-19
12+
#@ }
13+
ScalarFloat ScalarDouble 18 8:A 9:B
14+
5 5 5 1 1 1 5 1 1 1 10 0:0
15+
5 5 5 4 4 5 5 4 4 5 10 0:0
16+
3 3 3 1 1 1 3 1 1 1 10 0:0
17+
6 6 6 8 8 1 6 8 8 1 10 0:0
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=ScalarFloat:R4:0
5+
#@ col=ScalarDouble:R8:1
6+
#@ col=VectorFloat:R4:2-5
7+
#@ col=VectorDoulbe:R8:6-9
8+
#@ col=A:BL:10
9+
#@ col=B:BL:11
10+
#@ col=C:BL:12-15
11+
#@ col=D:BL:16-19
12+
#@ }
13+
ScalarFloat ScalarDouble 18 8:A 9:B
14+
5 5 5 1 1 1 5 1 1 1 10 0:0
15+
5 5 5 4 4 5 5 4 4 5 10 0:0
16+
3 3 3 1 1 1 3 1 1 1 10 0:0
17+
6 6 6 8 8 1 6 8 8 1 10 0:0

test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Data;
56
using Microsoft.ML.Runtime.Api;
67
using Microsoft.ML.Runtime.Data;
78
using Microsoft.ML.Runtime.Data.IO;
89
using Microsoft.ML.Runtime.Model;
910
using Microsoft.ML.Runtime.RunTests;
1011
using Microsoft.ML.Runtime.Tools;
12+
using System;
1113
using System.IO;
14+
using System.Linq;
1215
using Xunit;
1316
using Xunit.Abstractions;
1417

@@ -47,20 +50,43 @@ public void NAIndicatorWorkout()
4750
new NAIndicatorTransform.ColumnInfo("B", "NAC"),
4851
new NAIndicatorTransform.ColumnInfo("C", "NAD"),
4952
new NAIndicatorTransform.ColumnInfo("D", "NAE"));
50-
// write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier!
51-
//TestEstimatorCore(pipe, dataView);
53+
TestEstimatorCore(pipe, dataView);
5254
Done();
5355
}
5456

57+
//[Fact]
58+
//public void NAIndicatorSimpleWorkout()
59+
//{
60+
// var data = new[] {
61+
// new TestClass() { A = 1, B = 3, C = new float[2]{ 1, 2 } , D = new double[2]{ 3,4} },
62+
// new TestClass() { A = float.NaN, B = double.NaN, C = new float[2]{ float.NaN, float.NaN } , D = new double[2]{ double.NaN,double.NaN}},
63+
// new TestClass() { A = float.NegativeInfinity, B = double.NegativeInfinity, C = new float[2]{ float.NegativeInfinity, float.NegativeInfinity } , D = new double[2]{ double.NegativeInfinity, double.NegativeInfinity}},
64+
// new TestClass() { A = float.PositiveInfinity, B = double.PositiveInfinity, C = new float[2]{ float.PositiveInfinity, float.PositiveInfinity, } , D = new double[2]{ double.PositiveInfinity, double.PositiveInfinity}},
65+
// new TestClass() { A = 2, B = 1, C = new float[2]{ 3, 4 } , D = new double[2]{ 5,6}},
66+
// };
67+
68+
// var dataView = ComponentCreation.CreateDataView(Env, data);
69+
// var pipe = new NAIndicatorEstimator(Env,
70+
// new NAIndicatorTransform.ColumnInfo("A", "AA"));
71+
// var transform = pipe.Fit(dataView);
72+
// var outDataView = transform.Transform(dataView);
73+
// var col = outDataView.GetColumn<bool>(Env, "AA").ToArray();
74+
// // write a simple test with one NAindicatortransform and try to inspect the columns using pete's code. might be easier!
75+
// //TestEstimatorCore(pipe, dataView);
76+
// foreach(bool c in col)
77+
// {
78+
// Console.WriteLine(c.ToString());
79+
// }
80+
// Done();
81+
//}
82+
5583
[Fact]
5684
public void NAIndicatorStatic()
5785
{
5886
string dataPath = GetDataPath("breast-cancer.txt");
5987
var reader = TextLoader.CreateReader(Env, ctx => (
60-
ScalarString: ctx.LoadText(1),
6188
ScalarFloat: ctx.LoadFloat(1),
6289
ScalarDouble: ctx.LoadDouble(1),
63-
VectorString: ctx.LoadText(1, 4),
6490
VectorFloat: ctx.LoadFloat(1, 4),
6591
VectorDoulbe: ctx.LoadDouble(1, 4)
6692
));
@@ -71,11 +97,10 @@ public void NAIndicatorStatic()
7197

7298
var est = data.MakeNewEstimator().
7399
Append(row => (
74-
A: row.ScalarString.IsMissingValue(),
100+
A: row.ScalarFloat.IsMissingValue(),
75101
B: row.ScalarDouble.IsMissingValue(),
76-
C: row.VectorString.IsMissingValue(),
77-
D: row.VectorFloat.IsMissingValue(),
78-
F: row.VectorDoulbe.IsMissingValue()
102+
C: row.VectorFloat.IsMissingValue(),
103+
D: row.VectorDoulbe.IsMissingValue()
79104
));
80105

81106
TestEstimatorCore(est.AsDynamic, data.AsDynamic, invalidInput: invalidData);
@@ -84,7 +109,6 @@ public void NAIndicatorStatic()
84109
{
85110
var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true });
86111
IDataView savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4);
87-
savedData = new ChooseColumnsTransform(Env, savedData, "A", "B", "C", "D");
88112
using (var fs = File.Create(outputPath))
89113
DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);
90114
}

0 commit comments

Comments
 (0)