Skip to content

Commit 3f586bd

Browse files
authored
handle boolean type in construction utils. (#183)
handle boolean type in construction utils.
1 parent f16737c commit 3f586bd

File tree

4 files changed

+88
-2
lines changed

4 files changed

+88
-2
lines changed

src/Microsoft.ML.Api/ApiUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ private static OpCode GetAssignmentOpCode(Type t)
2121
// REVIEW: This should be a Dictionary<Type, OpCode> based solution.
2222
// DvTexts, strings, arrays, and VBuffers.
2323
if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) ||
24-
t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
24+
t == typeof(DvBool) || t==typeof(bool?) || t == typeof(DvText) || t == typeof(string) || t.IsArray ||
2525
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || t == typeof(DvDateTime) ||
2626
t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128))
2727
{

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

+37
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,19 @@ private Delegate CreateGetter(int index)
152152
Ch.Assert(colType.IsText);
153153
return CreateStringToTextGetter(index);
154154
}
155+
else if (outputType == typeof(bool))
156+
{
157+
// Bool -> DvBool
158+
Ch.Assert(colType.IsBool);
159+
return CreateBooleanToDvBoolGetter(index);
160+
}
161+
else if (outputType == typeof(bool?))
162+
{
163+
// Bool? -> DvBool
164+
Ch.Assert(colType.IsBool);
165+
return CreateNullableBooleanToDvBoolGetter(index);
166+
}
167+
155168
// T -> T
156169
Ch.Assert(colType.RawType == outputType);
157170
del = CreateDirectGetter<int>;
@@ -197,6 +210,30 @@ private Delegate CreateStringToTextGetter(int index)
197210
});
198211
}
199212

213+
private Delegate CreateBooleanToDvBoolGetter(int index)
214+
{
215+
var peek = DataView._peeks[index] as Peek<TRow, bool>;
216+
Ch.AssertValue(peek);
217+
bool buf = false;
218+
return (ValueGetter<DvBool>)((ref DvBool dst) =>
219+
{
220+
peek(GetCurrentRowObject(), Position, ref buf);
221+
dst = (DvBool)buf;
222+
});
223+
}
224+
225+
private Delegate CreateNullableBooleanToDvBoolGetter(int index)
226+
{
227+
var peek = DataView._peeks[index] as Peek<TRow, bool?>;
228+
Ch.AssertValue(peek);
229+
bool? buf = null;
230+
return (ValueGetter<DvBool>)((ref DvBool dst) =>
231+
{
232+
peek(GetCurrentRowObject(), Position, ref buf);
233+
dst = buf.HasValue ? (DvBool)buf.Value : DvBool.NA;
234+
});
235+
}
236+
200237
private Delegate CreateArrayToVBufferGetter<TDst>(int index)
201238
{
202239
var peek = DataView._peeks[index] as Peek<TRow, TDst[]>;

src/Microsoft.ML.Core/Data/DataKind.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ public static bool TryGetDataKind(this Type type, out DataKind kind)
207207
kind = DataKind.R8;
208208
else if (type == typeof(DvText))
209209
kind = DataKind.TX;
210-
else if (type == typeof(DvBool) || type == typeof(bool))
210+
else if (type == typeof(DvBool) || type == typeof(bool) ||type ==typeof(bool?))
211211
kind = DataKind.BL;
212212
else if (type == typeof(DvTimeSpan))
213213
kind = DataKind.TS;

test/Microsoft.ML.Tests/LearningPipelineTests.cs

+49
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,54 @@ public void NoTransformPipeline()
110110
pipeline.Add(new FastForestBinaryClassifier());
111111
var model = pipeline.Train<Data, Prediction>();
112112
}
113+
114+
public class BooleanLabelData
115+
{
116+
[ColumnName("Features")]
117+
[VectorType(2)]
118+
public float[] Features;
119+
120+
[ColumnName("Label")]
121+
public bool Label;
122+
}
123+
124+
[Fact]
125+
public void BooleanLabelPipeline()
126+
{
127+
var data = new BooleanLabelData[1];
128+
data[0] = new BooleanLabelData();
129+
data[0].Features = new float[] { 0.0f, 1.0f };
130+
data[0].Label = false;
131+
var pipeline = new LearningPipeline();
132+
pipeline.Add(CollectionDataSource.Create(data));
133+
pipeline.Add(new FastForestBinaryClassifier());
134+
var model = pipeline.Train<Data, Prediction>();
135+
}
136+
137+
public class NullableBooleanLabelData
138+
{
139+
[ColumnName("Features")]
140+
[VectorType(2)]
141+
public float[] Features;
142+
143+
[ColumnName("Label")]
144+
public bool? Label;
145+
}
146+
147+
[Fact]
148+
public void NullableBooleanLabelPipeline()
149+
{
150+
var data = new NullableBooleanLabelData[2];
151+
data[0] = new NullableBooleanLabelData();
152+
data[0].Features = new float[] { 0.0f, 1.0f };
153+
data[0].Label = null;
154+
data[1] = new NullableBooleanLabelData();
155+
data[1].Features = new float[] { 1.0f, 0.0f };
156+
data[1].Label = false;
157+
var pipeline = new LearningPipeline();
158+
pipeline.Add(CollectionDataSource.Create(data));
159+
pipeline.Add(new FastForestBinaryClassifier());
160+
var model = pipeline.Train<Data, Prediction>();
161+
}
113162
}
114163
}

0 commit comments

Comments
 (0)