Skip to content

Commit 0b239ea

Browse files
author
Pete Luferenko
committed
Fixed tests
1 parent dc0f5bf commit 0b239ea

File tree

1 file changed

+111
-95
lines changed

1 file changed

+111
-95
lines changed

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

+111-95
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
[assembly: LoadableClass(ConcatTransform.Summary, typeof(IDataTransform), typeof(ConcatTransform), null, typeof(SignatureLoadDataTransform),
2727
ConcatTransform.UserName, ConcatTransform.LoaderSignature, ConcatTransform.LoaderSignatureOld)]
2828

29+
[assembly: LoadableClass(typeof(ConcatTransform), null, typeof(SignatureLoadModel),
30+
ConcatTransform.UserName, ConcatTransform.LoaderSignature)]
31+
32+
[assembly: LoadableClass(typeof(IRowMapper), typeof(ConcatTransform), null, typeof(SignatureLoadRowMapper),
33+
ConcatTransform.UserName, ConcatTransform.LoaderSignature)]
34+
2935
namespace Microsoft.ML.Runtime.Data
3036
{
3137
using PfaType = PfaUtils.Type;
@@ -250,6 +256,9 @@ public void Save(ModelSaveContext ctx)
250256
col.Save(ctx);
251257
}
252258

259+
/// <summary>
260+
/// Constructor for SignatureLoadModel.
261+
/// </summary>
253262
public ConcatTransform(IHostEnvironment env, ModelLoadContext ctx)
254263
{
255264
Contracts.CheckValue(env, nameof(env));
@@ -770,11 +779,17 @@ private IDataTransform MakeDataTransform(IDataView input)
770779
public IRowMapper MakeRowMapper(ISchema inputSchema) => new Mapper(this, inputSchema);
771780

772781
/// <summary>
773-
/// Factory method for SignatureLoadDataTransform
782+
/// Factory method for SignatureLoadDataTransform.
774783
/// </summary>
775784
public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
776785
=> new ConcatTransform(env, ctx).MakeDataTransform(input);
777786

787+
/// <summary>
788+
/// Factory method for SignatureLoadRowMapper.
789+
/// </summary>
790+
public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
791+
=> new ConcatTransform(env, ctx).MakeRowMapper(inputSchema);
792+
778793
public ISchema GetOutputSchema(ISchema inputSchema)
779794
{
780795
_host.CheckValue(inputSchema, nameof(inputSchema));
@@ -933,16 +948,16 @@ public RowMapperColumnInfo MakeColumnInfo()
933948

934949
var metadata = new ColumnMetadataInfo(_columnInfo.Output);
935950
if (_isNormalized)
936-
metadata.Add(MetadataUtils.Kinds.IsNormalized, new MetadataInfo<bool>(BoolType.Instance, GetIsNormalized));
951+
metadata.Add(MetadataUtils.Kinds.IsNormalized, new MetadataInfo<DvBool>(BoolType.Instance, GetIsNormalized));
937952
if (_hasSlotNames)
938-
metadata.Add(MetadataUtils.Kinds.SlotNames, new MetadataInfo<VBuffer<DvText>>(TextType.Instance, GetSlotNames));
953+
metadata.Add(MetadataUtils.Kinds.SlotNames, new MetadataInfo<VBuffer<DvText>>(_slotNamesType, GetSlotNames));
939954
if (_hasCategoricals)
940-
metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges, new MetadataInfo<VBuffer<DvInt4>>(TextType.Instance, GetCategoricalSlotRanges));
955+
metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges, new MetadataInfo<VBuffer<DvInt4>>(_categoricalRangeType, GetCategoricalSlotRanges));
941956

942957
return new RowMapperColumnInfo(_columnInfo.Output, OutputType, metadata);
943958
}
944959

945-
private void GetIsNormalized(int col, ref bool value) => value = _isNormalized;
960+
private void GetIsNormalized(int col, ref DvBool value) => value = _isNormalized;
946961

947962
private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer<DvInt4> dst)
948963
{
@@ -1025,17 +1040,18 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
10251040
public Delegate MakeGetter(IRow input)
10261041
{
10271042
if (_isIdentity)
1028-
{
1029-
Contracts.Assert(SrcIndices.Length == 1);
1030-
Func<Delegate> getSrcGetter = () => input.GetGetter<int>(SrcIndices[0]);
1031-
return Utils.MarshalInvoke(getSrcGetter, _srcTypes[0].RawType);
1032-
}
1043+
return Utils.MarshalInvoke(MakeIdentityGetter<int>, OutputType.RawType, input);
10331044

1034-
Func<IRow, ValueGetter<VBuffer<int>>> del = MakeGetter<int>;
1035-
return Utils.MarshalInvoke(MakeGetter<int>, _srcTypes[0].RawType, input);
1045+
return Utils.MarshalInvoke(MakeGetter<int>, OutputType.ItemType.RawType, input);
1046+
}
1047+
1048+
private Delegate MakeIdentityGetter<T>(IRow input)
1049+
{
1050+
Contracts.Assert(SrcIndices.Length == 1);
1051+
return input.GetGetter<T>(SrcIndices[0]);
10361052
}
10371053

1038-
private ValueGetter<VBuffer<T>> MakeGetter<T>(IRow input)
1054+
private Delegate MakeGetter<T>(IRow input)
10391055
{
10401056
var srcGetterOnes = new ValueGetter<T>[SrcIndices.Length];
10411057
var srcGetterVecs = new ValueGetter<VBuffer<T>>[SrcIndices.Length];
@@ -1049,109 +1065,109 @@ private ValueGetter<VBuffer<T>> MakeGetter<T>(IRow input)
10491065

10501066
T tmp = default(T);
10511067
VBuffer<T>[] tmpBufs = new VBuffer<T>[SrcIndices.Length];
1052-
return
1053-
(ref VBuffer<T> dst) =>
1068+
ValueGetter<VBuffer<T>> result = (ref VBuffer<T> dst) =>
1069+
{
1070+
int dstLength = 0;
1071+
int dstCount = 0;
1072+
for (int i = 0; i < SrcIndices.Length; i++)
10541073
{
1055-
int dstLength = 0;
1056-
int dstCount = 0;
1057-
for (int i = 0; i < SrcIndices.Length; i++)
1074+
var type = _srcTypes[i];
1075+
if (type.IsVector)
10581076
{
1059-
var type = _srcTypes[i];
1060-
if (type.IsVector)
1061-
{
1062-
srcGetterVecs[i](ref tmpBufs[i]);
1063-
if (type.VectorSize != 0 && type.VectorSize != tmpBufs[i].Length)
1064-
{
1065-
throw Contracts.Except("Column '{0}': expected {1} slots, but got {2}",
1066-
input.Schema.GetColumnName(SrcIndices[i]), type.VectorSize, tmpBufs[i].Length)
1067-
.MarkSensitive(MessageSensitivity.Schema);
1068-
}
1069-
dstLength = checked(dstLength + tmpBufs[i].Length);
1070-
dstCount = checked(dstCount + tmpBufs[i].Count);
1071-
}
1072-
else
1077+
srcGetterVecs[i](ref tmpBufs[i]);
1078+
if (type.VectorSize != 0 && type.VectorSize != tmpBufs[i].Length)
10731079
{
1074-
dstLength = checked(dstLength + 1);
1075-
dstCount = checked(dstCount + 1);
1080+
throw Contracts.Except("Column '{0}': expected {1} slots, but got {2}",
1081+
input.Schema.GetColumnName(SrcIndices[i]), type.VectorSize, tmpBufs[i].Length)
1082+
.MarkSensitive(MessageSensitivity.Schema);
10761083
}
1084+
dstLength = checked(dstLength + tmpBufs[i].Length);
1085+
dstCount = checked(dstCount + tmpBufs[i].Count);
10771086
}
1087+
else
1088+
{
1089+
dstLength = checked(dstLength + 1);
1090+
dstCount = checked(dstCount + 1);
1091+
}
1092+
}
10781093

1079-
var values = dst.Values;
1080-
var indices = dst.Indices;
1081-
if (dstCount <= dstLength / 2)
1094+
var values = dst.Values;
1095+
var indices = dst.Indices;
1096+
if (dstCount <= dstLength / 2)
1097+
{
1098+
// Concatenate into a sparse representation.
1099+
if (Utils.Size(values) < dstCount)
1100+
values = new T[dstCount];
1101+
if (Utils.Size(indices) < dstCount)
1102+
indices = new int[dstCount];
1103+
1104+
int offset = 0;
1105+
int count = 0;
1106+
for (int j = 0; j < SrcIndices.Length; j++)
10821107
{
1083-
// Concatenate into a sparse representation.
1084-
if (Utils.Size(values) < dstCount)
1085-
values = new T[dstCount];
1086-
if (Utils.Size(indices) < dstCount)
1087-
indices = new int[dstCount];
1088-
1089-
int offset = 0;
1090-
int count = 0;
1091-
for (int j = 0; j < SrcIndices.Length; j++)
1108+
Contracts.Assert(offset < dstLength);
1109+
if (_srcTypes[j].IsVector)
10921110
{
1093-
Contracts.Assert(offset < dstLength);
1094-
if (_srcTypes[j].IsVector)
1111+
var buffer = tmpBufs[j];
1112+
Contracts.Assert(buffer.Count <= dstCount - count);
1113+
Contracts.Assert(buffer.Length <= dstLength - offset);
1114+
if (buffer.IsDense)
10951115
{
1096-
var buffer = tmpBufs[j];
1097-
Contracts.Assert(buffer.Count <= dstCount - count);
1098-
Contracts.Assert(buffer.Length <= dstLength - offset);
1099-
if (buffer.IsDense)
1116+
for (int i = 0; i < buffer.Length; i++)
11001117
{
1101-
for (int i = 0; i < buffer.Length; i++)
1102-
{
1103-
values[count] = buffer.Values[i];
1104-
indices[count++] = offset + i;
1105-
}
1118+
values[count] = buffer.Values[i];
1119+
indices[count++] = offset + i;
11061120
}
1107-
else
1108-
{
1109-
for (int i = 0; i < buffer.Count; i++)
1110-
{
1111-
values[count] = buffer.Values[i];
1112-
indices[count++] = offset + buffer.Indices[i];
1113-
}
1114-
}
1115-
offset += buffer.Length;
11161121
}
11171122
else
11181123
{
1119-
Contracts.Assert(count < dstCount);
1120-
srcGetterOnes[j](ref tmp);
1121-
values[count] = tmp;
1122-
indices[count++] = offset;
1123-
offset++;
1124+
for (int i = 0; i < buffer.Count; i++)
1125+
{
1126+
values[count] = buffer.Values[i];
1127+
indices[count++] = offset + buffer.Indices[i];
1128+
}
11241129
}
1130+
offset += buffer.Length;
1131+
}
1132+
else
1133+
{
1134+
Contracts.Assert(count < dstCount);
1135+
srcGetterOnes[j](ref tmp);
1136+
values[count] = tmp;
1137+
indices[count++] = offset;
1138+
offset++;
11251139
}
1126-
Contracts.Assert(count <= dstCount);
1127-
Contracts.Assert(offset == dstLength);
1128-
dst = new VBuffer<T>(dstLength, count, values, indices);
11291140
}
1130-
else
1131-
{
1132-
// Concatenate into a dense representation.
1133-
if (Utils.Size(values) < dstLength)
1134-
values = new T[dstLength];
1141+
Contracts.Assert(count <= dstCount);
1142+
Contracts.Assert(offset == dstLength);
1143+
dst = new VBuffer<T>(dstLength, count, values, indices);
1144+
}
1145+
else
1146+
{
1147+
// Concatenate into a dense representation.
1148+
if (Utils.Size(values) < dstLength)
1149+
values = new T[dstLength];
11351150

1136-
int offset = 0;
1137-
for (int j = 0; j < SrcIndices.Length; j++)
1151+
int offset = 0;
1152+
for (int j = 0; j < SrcIndices.Length; j++)
1153+
{
1154+
Contracts.Assert(tmpBufs[j].Length <= dstLength - offset);
1155+
if (_srcTypes[j].IsVector)
11381156
{
1139-
Contracts.Assert(tmpBufs[j].Length <= dstLength - offset);
1140-
if (_srcTypes[j].IsVector)
1141-
{
1142-
tmpBufs[j].CopyTo(values, offset);
1143-
offset += tmpBufs[j].Length;
1144-
}
1145-
else
1146-
{
1147-
srcGetterOnes[j](ref tmp);
1148-
values[offset++] = tmp;
1149-
}
1157+
tmpBufs[j].CopyTo(values, offset);
1158+
offset += tmpBufs[j].Length;
1159+
}
1160+
else
1161+
{
1162+
srcGetterOnes[j](ref tmp);
1163+
values[offset++] = tmp;
11501164
}
1151-
Contracts.Assert(offset == dstLength);
1152-
dst = new VBuffer<T>(dstLength, values, indices);
11531165
}
1154-
};
1166+
Contracts.Assert(offset == dstLength);
1167+
dst = new VBuffer<T>(dstLength, values, indices);
1168+
}
1169+
};
1170+
return result;
11551171
}
11561172

11571173
public KeyValuePair<string, JToken> SavePfaInfo(BoundPfaContext ctx)

0 commit comments

Comments
 (0)