Skip to content

Commit 7d7ebb6

Browse files
authored
TensorFlow: Fixed shape issue where unknown shape will be induced from data. (#2475)
1 parent 07580a8 commit 7d7ebb6

File tree

3 files changed

+135
-29
lines changed

3 files changed

+135
-29
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
<PropertyGroup>
4444
<BenchmarkDotNetVersion>0.11.3</BenchmarkDotNetVersion>
4545
<MicrosoftMLTestModelsPackageVersion>0.0.3-test</MicrosoftMLTestModelsPackageVersion>
46-
<MicrosoftMLTensorFlowTestModelsVersion>0.0.7-test</MicrosoftMLTensorFlowTestModelsVersion>
46+
<MicrosoftMLTensorFlowTestModelsVersion>0.0.10-test</MicrosoftMLTensorFlowTestModelsVersion>
4747
<MicrosoftMLOnnxTestModelsVersion>0.0.4-test</MicrosoftMLOnnxTestModelsVersion>
4848
</PropertyGroup>
4949

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -667,15 +667,6 @@ internal static (TFDataType[] tfInputTypes, TFShape[] tfInputShapes) GetInputInf
667667
var tfInput = new TFOutput(session.Graph[inputs[i]]);
668668
tfInputTypes[i] = tfInput.OutputType;
669669
tfInputShapes[i] = session.Graph.GetTensorShape(tfInput);
670-
if (tfInputShapes[i].NumDimensions != -1)
671-
{
672-
var newShape = new long[tfInputShapes[i].NumDimensions];
673-
newShape[0] = tfInputShapes[i][0] == -1 ? BatchSize : tfInputShapes[i][0];
674-
675-
for (int j = 1; j < tfInputShapes[i].NumDimensions; j++)
676-
newShape[j] = tfInputShapes[i][j];
677-
tfInputShapes[i] = new TFShape(newShape);
678-
}
679670
}
680671
return (tfInputTypes, tfInputShapes);
681672
}
@@ -698,7 +689,14 @@ internal static (TFDataType[] tfOutputTypes, ColumnType[] outputTypes) GetOutput
698689
{
699690
var tfOutput = new TFOutput(session.Graph[outputs[i]]);
700691
var shape = session.Graph.GetTensorShape(tfOutput);
692+
693+
// The transformer can only retrieve the output as a fixed-length vector with a shape of the kind [-1, d1, d2, d3, ...]
694+
// i.e. the first dimension (if unknown) is assumed to be the batch dimension.
695+
// If any other dimensions are unknown, the transformer will return a variable-length vector.
696+
// This is a workaround in the absence of a reshape transformer.
701697
int[] dims = shape.NumDimensions > 0 ? shape.ToIntArray().Skip(shape[0] == -1 ? 1 : 0).ToArray() : new[] { 0 };
698+
for (int j = 0; j < dims.Length; j++)
699+
dims[j] = dims[j] == -1 ? 0 : dims[j];
702700
var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
703701
outputTypes[i] = new VectorType(type, dims);
704702
tfOutputTypes[i] = tfOutput.OutputType;
@@ -837,14 +835,22 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
837835
var originalShape = _parent.TFInputShapes[i];
838836
var shape = originalShape.ToIntArray();
839837

840-
var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray();
838+
var colTypeDims = vecType.Dimensions.Select(dim => (long)dim).ToArray();
841839
if (shape == null)
842840
_fullySpecifiedShapes[i] = new TFShape(colTypeDims);
843-
else if (vecType.Dimensions.Length == 1)
841+
else
844842
{
845843
// If the column is one dimension we make sure that the total size of the TF shape matches.
846844
// Compute the total size of the known dimensions of the shape.
847-
int valCount = shape.Where(x => x > 0).Aggregate((x, y) => x * y);
845+
int valCount = 1;
846+
int numOfUnkDim = 0;
847+
foreach (var s in shape)
848+
{
849+
if (s > 0)
850+
valCount *= s;
851+
else
852+
numOfUnkDim++;
853+
}
848854
// The column length should be divisible by this, so that the other dimensions can be integral.
849855
int typeValueCount = type.GetValueCount();
850856
if (typeValueCount % valCount != 0)
@@ -853,8 +859,8 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
853859
// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
854860
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
855861
// d such that d*d*3 is equal to the length of the input column.
856-
var d = originalShape.NumDimensions > 2 ? Math.Pow(typeValueCount / valCount, 1.0 / (originalShape.NumDimensions - 2)) : 1;
857-
if (originalShape.NumDimensions > 2 && d - (int)d != 0)
862+
var d = numOfUnkDim > 0 ? Math.Pow(typeValueCount / valCount, 1.0 / numOfUnkDim) : 0;
863+
if (d - (int)d != 0)
858864
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
859865

860866
// Fill in the unknown dimensions.
@@ -863,17 +869,6 @@ public Mapper(TensorFlowTransformer parent, Schema inputSchema) :
863869
l[ishape] = originalShape[ishape] == -1 ? (int)d : originalShape[ishape];
864870
_fullySpecifiedShapes[i] = new TFShape(l);
865871
}
866-
else
867-
{
868-
if (shape.Select((dim, j) => dim != -1 && dim != colTypeDims[j]).Any(b => b))
869-
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is {vecType.ToString()}.");
870-
871-
// Fill in the unknown dimensions.
872-
var l = new long[originalShape.NumDimensions];
873-
for (int ishape = 0; ishape < originalShape.NumDimensions; ishape++)
874-
l[ishape] = originalShape[ishape] == -1 ? colTypeDims[ishape] : originalShape[ishape];
875-
_fullySpecifiedShapes[i] = new TFShape(l);
876-
}
877872
}
878873
}
879874

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,117 @@ public void TensorFlowTransformMatrixMultiplicationTest()
7474
}
7575
}
7676

77+
private class ShapeData
78+
{
79+
// Data will be passed as 1-D vector.
80+
// Intended data shape [5], model shape [None]
81+
[VectorType(5)]
82+
public float[] OneDim;
83+
84+
// Data will be passed as flat vector.
85+
// Intended data shape [2,2], model shape [2, None]
86+
[VectorType(4)]
87+
public float[] TwoDim;
88+
89+
// Data will be passed as 3-D vector.
90+
// Intended data shape [1, 2, 2], model shape [1, None, 2]
91+
[VectorType(1, 2, 2)]
92+
public float[] ThreeDim;
93+
94+
// Data will be passed as flat vector.
95+
// Intended data shape [1, 2, 2, 3], model shape [1, None, None, 3]
96+
[VectorType(12)]
97+
public float[] FourDim;
98+
99+
// Data will be passed as 4-D vector.
100+
// Intended data shape [2, 2, 2, 2], model shape [2, 2, 2, 2]
101+
[VectorType(2, 2, 2, 2)]
102+
public float[] FourDimKnown;
103+
}
104+
105+
private List<ShapeData> GetShapeData()
106+
{
107+
return new List<ShapeData>(new ShapeData[] {
108+
new ShapeData() { OneDim = new[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f },
109+
TwoDim = new[] { 1.0f, 2.0f, 3.0f, 4.0f },
110+
ThreeDim = new[] { 11.0f, 12.0f, 13.0f, 14.0f },
111+
FourDim = new[]{ 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f,
112+
27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f },
113+
FourDimKnown = new[]{ 41.0f , 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
114+
49.0f , 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f}
115+
},
116+
new ShapeData() { OneDim = new[] { 100.1f, 100.2f, 100.3f, 100.4f, 100.5f },
117+
TwoDim = new[] { 101.0f, 102.0f, 103.0f, 104.0f },
118+
ThreeDim = new[] { 111.0f, 112.0f, 113.0f, 114.0f },
119+
FourDim = new[]{ 121.0f, 122.0f, 123.0f, 124.0f, 125.0f, 126.0f,
120+
127.0f, 128.0f, 129.0f, 130.0f, 131.0f, 132.0f},
121+
FourDimKnown = new[]{ 141.0f , 142.0f, 143.0f, 144.0f, 145.0f, 146.0f, 147.0f, 148.0f,
122+
149.0f , 150.0f, 151.0f, 152.0f, 153.0f, 154.0f, 155.0f, 156.0f }
123+
}
124+
});
125+
}
126+
127+
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
128+
public void TensorFlowTransformInputShapeTest()
129+
{
130+
var modelLocation = "model_shape_test";
131+
var mlContext = new MLContext(seed: 1, conc: 1);
132+
var data = GetShapeData();
133+
// Pipeline
134+
var loader = mlContext.Data.ReadFromEnumerable(data);
135+
var inputs = new string[] { "OneDim", "TwoDim", "ThreeDim", "FourDim", "FourDimKnown" };
136+
var outputs = new string[] { "o_OneDim", "o_TwoDim", "o_ThreeDim", "o_FourDim", "o_FourDimKnown" };
137+
138+
var trans = mlContext.Transforms.ScoreTensorFlowModel(modelLocation, outputs, inputs).Fit(loader).Transform(loader);
139+
140+
using (var cursor = trans.GetRowCursorForAllColumns())
141+
{
142+
int outColIndex = 5;
143+
var oneDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex);
144+
var twoDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 1);
145+
var threeDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 2);
146+
var fourDimgetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 3);
147+
var fourDimKnowngetter = cursor.GetGetter<VBuffer<float>>(outColIndex + 4);
148+
149+
VBuffer<float> oneDim = default;
150+
VBuffer<float> twoDim = default;
151+
VBuffer<float> threeDim = default;
152+
VBuffer<float> fourDim = default;
153+
VBuffer<float> fourDimKnown = default;
154+
foreach (var sample in data)
155+
{
156+
Assert.True(cursor.MoveNext());
157+
158+
oneDimgetter(ref oneDim);
159+
twoDimgetter(ref twoDim);
160+
threeDimgetter(ref threeDim);
161+
fourDimgetter(ref fourDim);
162+
fourDimKnowngetter(ref fourDimKnown);
163+
164+
var oneDimValues = oneDim.GetValues();
165+
Assert.Equal(sample.OneDim.Length, oneDimValues.Length);
166+
Assert.True(oneDimValues.SequenceEqual(sample.OneDim));
167+
168+
var twoDimValues = twoDim.GetValues();
169+
Assert.Equal(sample.TwoDim.Length, twoDimValues.Length);
170+
Assert.True(twoDimValues.SequenceEqual(sample.TwoDim));
171+
172+
var threeDimValues = threeDim.GetValues();
173+
Assert.Equal(sample.ThreeDim.Length, threeDimValues.Length);
174+
Assert.True(threeDimValues.SequenceEqual(sample.ThreeDim));
175+
176+
var fourDimValues = fourDim.GetValues();
177+
Assert.Equal(sample.FourDim.Length, fourDimValues.Length);
178+
Assert.True(fourDimValues.SequenceEqual(sample.FourDim));
179+
180+
var fourDimKnownValues = fourDimKnown.GetValues();
181+
Assert.Equal(sample.FourDimKnown.Length, fourDimKnownValues.Length);
182+
Assert.True(fourDimKnownValues.SequenceEqual(sample.FourDimKnown));
183+
}
184+
Assert.False(cursor.MoveNext());
185+
}
186+
}
187+
77188
private class TypesData
78189
{
79190
[VectorType(2)]
@@ -142,7 +253,7 @@ public void TensorFlowTransformInputOutputTypesTest()
142253

143254
var loader = mlContext.Data.ReadFromEnumerable(data);
144255

145-
var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"};
256+
var inputs = new string[] { "f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8", "b" };
146257
var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" };
147258
var trans = mlContext.Transforms.ScoreTensorFlowModel(model_location, outputs, inputs).Fit(loader).Transform(loader); ;
148259

@@ -160,7 +271,7 @@ public void TensorFlowTransformInputOutputTypesTest()
160271
var u8getter = cursor.GetGetter<VBuffer<byte>>(20);
161272
var boolgetter = cursor.GetGetter<VBuffer<bool>>(21);
162273

163-
274+
164275
VBuffer<double> f64 = default;
165276
VBuffer<float> f32 = default;
166277
VBuffer<long> i64 = default;
@@ -449,7 +560,7 @@ public void TensorFlowTransformMNISTLRTrainingTest()
449560
ReTrain = true
450561
}))
451562
.Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
452-
.Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel","Label", maxNumKeys: 10))
563+
.Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel", "Label", maxNumKeys: 10))
453564
.Append(mlContext.MulticlassClassification.Trainers.LightGbm("KeyLabel", "Features"));
454565

455566
var trainedModel = pipe.Fit(trainData);

0 commit comments

Comments
 (0)