
Commit 53fbe7b

Add cases for view, type_as, and masked_fill
1 parent 37f57a9 commit 53fbe7b

8 files changed: 153 additions, 2 deletions


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Lines changed: 2 additions & 0 deletions

@@ -6133,6 +6133,7 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
     printDefaultTorchOp(printer, *this, 2, 1);
   }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenViewOp : Torch_Op<"aten.view", [
@@ -6157,6 +6158,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
   }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
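These two `let` additions declare folding and canonicalization hooks for `aten.type_as` and `aten.view`; the C++ implementations land in TorchOps.cpp below. At the PyTorch level, the `type_as` folder relies on the op being an identity whenever the two tensors already have the same type. A minimal eager-mode sketch of that invariant (variable names here are illustrative only):

import torch

x = torch.rand(3, 5)   # float32
y = torch.rand(3, 5)   # float32, same dtype as x
# type_as only casts x to y's dtype, so when the dtypes already match
# it is an identity -- the invariant the new AtenTypeAsOp folder exploits.
z = torch.ops.aten.type_as(x, y)
assert z.dtype == x.dtype
assert torch.equal(z, x)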

lib/Conversion/TorchToLinalg/DataMovement.cpp
Lines changed: 37 additions & 0 deletions

@@ -301,6 +301,42 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           op, "desired size list length mismatches with the result type rank");
     }
 
+    // Case where all sizes (input and output) are statically known.
+    SmallVector<int64_t> outputSizeIntList;
+    if (matchPattern(op.size(), m_TorchConstantIntList(outputSizeIntList)) &&
+        llvm::none_of(outputSizeIntList,
+                      [](int64_t x) { return x == kUnknownSize; }) &&
+        inputType.hasStaticShape()) {
+
+      int64_t midSize = 1;
+      for (auto outSize : outputSizeIntList)
+        midSize *= outSize;
+      auto midType =
+          RankedTensorType::get({midSize}, resultType.getElementType());
+      auto outType =
+          RankedTensorType::get(outputSizeIntList, resultType.getElementType());
+      int64_t outputRank = outputSizeIntList.size();
+
+      SmallVector<ReassociationIndices> flatten(1);
+      for (auto i = 0; i < inputRank; i++)
+        flatten[0].push_back(i);
+      SmallVector<ReassociationIndices> unflatten(1);
+      for (auto i = 0; i < outputRank; i++)
+        unflatten[0].push_back(i);
+
+      Value flattened =
+          (inputRank <= 1 ? input
+                          : rewriter.create<tensor::CollapseShapeOp>(
+                                loc, midType, input, flatten));
+      Value result =
+          (outputRank <= 1 ? flattened
+                           : rewriter.create<tensor::ExpandShapeOp>(
+                                 loc, outType, flattened, unflatten));
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+      return success();
+    }
+
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
     // collapsing nor expanding like view of [2,3] for 3x2 tensor.
@@ -319,6 +355,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     // collapsed. Note this may technically not always be true.
     // TODO: think of a way better way to at least detect when this assumption
     // is violated for the cases of dynamic dimensions.
+
    SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
    SmallVector<ReassociationIndices> unchangedDims;
    llvm::Optional<int64_t> inferredDimension;
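The new static-shape path in ConvertAtenViewOp sidesteps the dimension-association analysis entirely: it collapses the input to a single dimension with tensor.collapse_shape and then expands it to the target shape with tensor.expand_shape, which is only valid because a view preserves the total element count. A rough eager-mode analogue of that flatten-then-unflatten strategy, using the shapes from the ViewStaticModule test added below:

import torch

a = torch.rand(5, 12, 3, 8)        # 5*12*3*8 = 1440 elements, all static
# Collapse to one dimension, then expand to the target shape; this mirrors
# the tensor.collapse_shape / tensor.expand_shape pair used by the lowering.
mid = a.reshape(-1)                # shape [1440]
out = mid.reshape(6, 5, 6, 8)      # 6*5*6*8 = 1440 elements as well
assert torch.equal(out, a.view(6, 5, 6, 8))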

lib/Conversion/TorchToLinalg/Uncategorized.cpp
Lines changed: 3 additions & 0 deletions

@@ -892,6 +892,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 
     Value input = payloadArgs[0];
     Value mask = payloadArgs[1];
+    if (mask.getType().isa<mlir::FloatType>())
+      mask = b.create<arith::FPToUIOp>(
+          loc, IntegerType::get(op->getContext(), 1), mask);
     Value fillValue = convertScalarToDtype(b, loc, adaptor.value(), dtype);
 
     return b.create<arith::SelectOp>(loc, mask, fillValue, input);
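The masked_fill payload previously assumed an i1 mask; the new branch converts a floating-point mask to i1 with arith.fptoui before the arith.select. For a mask that only holds 0.0 and 1.0, as in the MaskedFillTensorFloatMaskModule test added below, the truncation agrees with an explicit bool cast. A hedged eager-mode sketch of that equivalence:

import torch

x = torch.randint(-10, 10, (2, 3))
float_mask = torch.randint(0, 2, (2, 3)).to(torch.float)   # only 0.0 / 1.0
# fptoui truncates 0.0 -> 0 and 1.0 -> 1, so for a 0/1 float mask the
# lowering behaves like casting the mask to bool before the fill.
bool_mask = float_mask.to(torch.bool)
filled = x.masked_fill(bool_mask, 0)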

lib/Dialect/Torch/IR/TorchOps.cpp
Lines changed: 46 additions & 0 deletions

@@ -673,6 +673,22 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenTypeAsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
+  Value input = getOperand(0);
+  Value output = getOperand(1);
+  Type inType = input.getType();
+  Type outType = output.getType();
+
+  if (inType == outType)
+    return input;
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenToDtypeOp
 //===----------------------------------------------------------------------===//
@@ -780,6 +796,36 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
   return getOperand(0);
 }
 
+void AtenViewOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add(+[](AtenViewOp op, PatternRewriter &rewriter) {
+    Location loc = op->getLoc();
+    auto outType = op.getType().dyn_cast<BaseTensorType>();
+    Value size = op.size();
+    if (!outType || !outType.hasSizes())
+      return failure();
+    // When all result sizes are statically known, replace the sizes argument
+    // with a constant list.
+    auto staticSizes = outType.getSizes();
+    SmallVector<int64_t> constSizeList;
+    if (!matchPattern(size, m_TorchConstantIntList(constSizeList)) &&
+        llvm::none_of(staticSizes,
+                      [](int64_t x) { return x == kUnknownSize; })) {
+      SmallVector<Value> staticSizeValues;
+      for (auto staticSize : staticSizes)
+        staticSizeValues.push_back(rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(staticSize)));
+      Value argList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
+          staticSizeValues);
+      rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
+                                              argList);
+      return success();
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDimOp
 //===----------------------------------------------------------------------===//
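The AtenViewOp canonicalizer targets views whose result type already has fully static sizes but whose size operand is not yet a constant list: it materializes constant ints plus a prim.ListConstruct so later patterns, such as the static-shape lowering above that matches m_TorchConstantIntList, can fire. A simplified, purely illustrative Python model of the two new hooks (the function names and the kUnknownSize sentinel are stand-ins, not the real API):

K_UNKNOWN_SIZE = -1   # stand-in for torch-mlir's kUnknownSize sentinel

def fold_type_as(input_type, other_type):
    """AtenTypeAsOp folder: the op folds to its first operand when both
    operand types are already identical; otherwise it is left alone."""
    return "input" if input_type == other_type else None

def canonicalize_view_sizes(result_sizes, size_is_constant_list):
    """AtenViewOp canonicalizer: when the size operand is not yet a constant
    list but every result size is static, rebuild it as a constant list."""
    if size_is_constant_list:
        return None                                  # nothing to rewrite
    if any(s == K_UNKNOWN_SIZE for s in result_sizes):
        return None                                  # dynamic dims: give up
    return list(result_sizes)                        # emitted as constants + ListConstruct

assert fold_type_as("!torch.vtensor<[3,5],f32>", "!torch.vtensor<[3,5],f32>") == "input"
assert canonicalize_view_sizes([6, 5, 6, 8], size_is_constant_list=False) == [6, 5, 6, 8]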

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
Lines changed: 2 additions & 2 deletions

@@ -464,8 +464,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
     emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
-    emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
-    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
+    emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True)
+    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True, has_canonicalizer=True)
     emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
     emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")

python/torch_mlir_e2e_test/test_suite/constant_alloc.py
Lines changed: 23 additions & 0 deletions

@@ -1462,3 +1462,26 @@ def forward(self, x, mask, value):
 def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 3, low=-10, high=10),
                    tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
+
+
+class MaskedFillTensorFloatMaskModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.float, True),
+    ])
+    def forward(self, x, mask):
+        return torch.ops.aten.masked_fill(x, mask, value=0.1)
+
+
+@register_test_case(module_factory=lambda: MaskedFillTensorFloatMaskModule())
+def MaskedFillTensorFloatMaskModule_basic(module, tu: TestUtils):
+    mask = tu.randint(2, 3, low=0, high=1).to(torch.float)
+    module.forward(tu.randint(2, 3, low=-10, high=10),
+                   mask)
+

python/torch_mlir_e2e_test/test_suite/reshape_like.py
Lines changed: 19 additions & 0 deletions

@@ -259,6 +259,25 @@ def View1DFoldModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ViewStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 12, 3, 8], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(6, 5, 6, 8)
+
+@register_test_case(module_factory=lambda: ViewStaticModule())
+def ViewStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 12, 3, 8))
+
+# ==============================================================================
+
 class ViewCollapseInferredDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

python/torch_mlir_e2e_test/test_suite/type_conversion.py
Lines changed: 21 additions & 0 deletions

@@ -214,3 +214,24 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
 def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+class TypeAsSameModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.type_as(x,
+                                      y)
+
+
+
+@register_test_case(module_factory=lambda: TypeAsSameModule())
+def TypeAsSameModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(3, 5))
