diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 52b1b28849..ff3a424f88 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -730,7 +730,7 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : var originalShape = _parent.TFInputShapes[i]; var shape = originalShape.ToIntArray(); - var colTypeDims = vecType.Dimensions.Select(dim => (long)dim).ToArray(); + var colTypeDims = vecType.Dimensions.Prepend(1).Select(dim => (long)dim).ToArray(); if (shape == null) _fullySpecifiedShapes[i] = new TFShape(colTypeDims); else