Skip to content

Commit a41e6d9

Browse files
authored
Added support for inserting batch dimension in inputs in TensorFlow. (#2935)
1 parent 32017a3 commit a41e6d9

File tree

6 files changed

+104
-46
lines changed

6 files changed

+104
-46
lines changed

src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs

+15-12
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ private sealed class OutColumn : Vector<float>
1414
{
1515
public PipelineColumn Input { get; }
1616

17-
public OutColumn(Vector<float> input, string modelFile)
18-
: base(new Reconciler(modelFile), input)
17+
public OutColumn(Vector<float> input, string modelFile, bool addBatchDimensionInput)
18+
: base(new Reconciler(modelFile, addBatchDimensionInput), input)
1919
{
2020
Input = input;
2121
}
2222

23-
public OutColumn(Vector<float> input, TensorFlowModel tensorFlowModel)
24-
: base(new Reconciler(tensorFlowModel), input)
23+
public OutColumn(Vector<float> input, TensorFlowModel tensorFlowModel, bool addBatchDimensionInput)
24+
: base(new Reconciler(tensorFlowModel, addBatchDimensionInput), input)
2525
{
2626
Input = input;
2727
}
@@ -31,20 +31,23 @@ private sealed class Reconciler : EstimatorReconciler
3131
{
3232
private readonly string _modelFile;
3333
private readonly TensorFlowModel _tensorFlowModel;
34+
private readonly bool _addBatchDimensionInput;
3435

35-
public Reconciler(string modelFile)
36+
public Reconciler(string modelFile, bool addBatchDimensionInput)
3637
{
3738
Contracts.AssertNonEmpty(modelFile);
3839
_modelFile = modelFile;
3940
_tensorFlowModel = null;
41+
_addBatchDimensionInput = addBatchDimensionInput;
4042
}
4143

42-
public Reconciler(TensorFlowModel tensorFlowModel)
44+
public Reconciler(TensorFlowModel tensorFlowModel, bool addBatchDimensionInput)
4345
{
4446
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));
4547

4648
_modelFile = null;
4749
_tensorFlowModel = tensorFlowModel;
50+
_addBatchDimensionInput = addBatchDimensionInput;
4851
}
4952

5053
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
@@ -57,9 +60,9 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
5760

5861
var outCol = (OutColumn)toOutput[0];
5962
if (_modelFile == null)
60-
return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _tensorFlowModel);
63+
return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _tensorFlowModel, _addBatchDimensionInput);
6164
else
62-
return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile);
65+
return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile, _addBatchDimensionInput);
6366
}
6467
}
6568

@@ -70,22 +73,22 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
7073
/// Load the TensorFlow model from <paramref name="modelFile"/> and run it on the input column and extract one output column.
7174
/// The inputs and outputs are matched to TensorFlow graph nodes by name.
7275
/// </summary>
73-
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, string modelFile)
76+
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, string modelFile, bool addBatchDimensionInput = false)
7477
{
7578
Contracts.CheckValue(input, nameof(input));
7679
Contracts.CheckNonEmpty(modelFile, nameof(modelFile));
77-
return new OutColumn(input, modelFile);
80+
return new OutColumn(input, modelFile, addBatchDimensionInput);
7881
}
7982

8083
/// <summary>
8184
/// Run a TensorFlow model provided through <paramref name="tensorFlowModel"/> on the input column and extract one output column.
8285
/// The inputs and outputs are matched to TensorFlow graph nodes by name.
8386
/// </summary>
84-
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, TensorFlowModel tensorFlowModel)
87+
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, TensorFlowModel tensorFlowModel, bool addBatchDimensionInput = false)
8588
{
8689
Contracts.CheckValue(input, nameof(input));
8790
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));
88-
return new OutColumn(input, tensorFlowModel);
91+
return new OutColumn(input, tensorFlowModel, addBatchDimensionInput);
8992
}
9093
}
9194
}

src/Microsoft.ML.TensorFlow/TensorFlowModel.cs

+14-6
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,34 @@ public DataViewSchema GetInputSchema()
5555
/// </summary>
5656
/// <param name="inputColumnName"> The name of the model input.</param>
5757
/// <param name="outputColumnName">The name of the requested model output.</param>
58+
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
59+
/// This parameter is used for models whose input shape is unknown but whose internal operators require the data to have a batch dimension as well.</param>
5860
/// <example>
5961
/// <format type="text/markdown">
6062
/// <![CDATA[
6163
/// [!code-csharp[ScoreTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs)]
6264
/// ]]>
6365
/// </format>
6466
/// </example>
65-
public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName)
66-
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this);
67+
public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName, bool addBatchDimensionInput = false)
68+
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this, addBatchDimensionInput);
6769

6870
/// <summary>
6971
/// Scores a dataset using a pre-trained TensorFlow model.
7072
/// </summary>
7173
/// <param name="inputColumnNames"> The names of the model inputs.</param>
7274
/// <param name="outputColumnNames">The names of the requested model outputs.</param>
75+
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
76+
/// This parameter is used for models whose input shape is unknown but whose internal operators require the data to have a batch dimension as well.</param>
7377
/// <example>
7478
/// <format type="text/markdown">
7579
/// <![CDATA[
7680
/// [!code-csharp[ScoreTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/ImageClassification.cs)]
7781
/// ]]>
7882
/// </format>
7983
/// </example>
80-
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames)
81-
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this);
84+
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false)
85+
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this, addBatchDimensionInput);
8286

8387
/// <summary>
8488
/// Retrain the TensorFlow model on new data.
@@ -97,6 +101,8 @@ public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, stri
97101
/// <param name="metricOperation">The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).</param>
98102
/// <param name="learningRateOperation">The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).</param>
99103
/// <param name="learningRate">Learning rate to use during optimization (Optional).</param>
104+
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
105+
/// This parameter is used for models whose input shape is unknown but whose internal operators require the data to have a batch dimension as well.</param>
100106
/// <remarks>
101107
/// The support for retraining is experimental.
102108
/// </remarks>
@@ -111,7 +117,8 @@ public TensorFlowEstimator RetrainTensorFlowModel(
111117
string lossOperation = null,
112118
string metricOperation = null,
113119
string learningRateOperation = null,
114-
float learningRate = 0.01f)
120+
float learningRate = 0.01f,
121+
bool addBatchDimensionInput = false)
115122
{
116123
var options = new TensorFlowEstimator.Options()
117124
{
@@ -127,7 +134,8 @@ public TensorFlowEstimator RetrainTensorFlowModel(
127134
LearningRateOperation = learningRateOperation,
128135
LearningRate = learningRate,
129136
BatchSize = batchSize,
130-
ReTrain = true
137+
ReTrain = true,
138+
AddBatchDimensionInputs = addBatchDimensionInput
131139
};
132140
return new TensorFlowEstimator(_env, options, this);
133141
}

src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public static class TensorflowCatalog
1414
{
1515
/// <summary>
1616
/// Loads a TensorFlow model into memory. This is a convenience method that allows the model to be loaded once and subsequently used for querying the schema and for creation of
17-
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string)"/>.
17+
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string, bool)"/>.
1818
/// </summary>
1919
/// <param name="catalog">The transform's catalog.</param>
2020
/// <param name="modelLocation">Location of the TensorFlow model.</param>

0 commit comments

Comments
 (0)