
Commit f5253d4

Merge branch 'main' of https://github.com/llvm/torch-mlir into canonicalization-for-aten-add-tensor-op

2 parents: 3d3b1bd + a34dad2

16 files changed, +366 -17 lines changed

docs/adding_a_shape_function.md

Lines changed: 7 additions & 3 deletions
@@ -26,9 +26,13 @@ We will use the example of adding support for the `torch.aten.tanh` op.
    functions don't get outdated if Torch changes an operator signature.
 
 3. Fill in the body of the shape function. Ideally this will just be a call into
-   a helper function from `upstream_shape_helpers.py`. But in general, you will
-   need to write the shape function and test it (see the comments about "Shape
-   function testing infrastructure" in `shape_lib_gen.py`).
+   a helper function from
+   [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
+   But in general, you will need to write the shape function and test it (see
+   the comments about "Shape function testing infrastructure" in
+   `shape_lib_gen.py`). New shape functions should be added upstream following
+   the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
+   though it can be useful to iterate locally in `shape_lib_gen.py` first.
 
 4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
    library. After this step happens, ideally everything "just works" and the
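For the `torch.aten.tanh` example mentioned in this doc, step 3 really is a one-liner: tanh is elementwise, so the output shape equals the input shape. A minimal sketch, assuming the upstream module is visible under the name `upstream_shape_functions` (as the existing functions in `shape_lib_gen.py` use it) and that the upstream `unary` helper is the one being reused:

```python
from typing import List

# One way to get the alias used throughout shape_lib_gen.py.
import torch.jit._shape_functions as upstream_shape_functions

# Shape function for `torch.aten.tanh`: elementwise, so the result shape is
# just a copy of the input shape; `unary` does exactly that.
def aten〇tanh(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
```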

docs/shape_lib.md

Lines changed: 2 additions & 9 deletions
@@ -26,8 +26,8 @@ Shape functions are defined as TorchScript-able Python functions in
 The signatures of the shape functions are systematically derived from Torch JIT
 operator registry (mainly by replacing `Tensor` with `List[int]` in the operator
 signatures). Most shape functions are expected to reuse the upstream helper
-functions in
-`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py`.
+functions [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
+and any new shape functions should be added there.
 
 The `build_tools/update_shape_lib.sh` script invokes `shape_lib_gen.py` to
 generate an MLIR module containing the shape functions, which is currently
@@ -119,10 +119,3 @@ was based on the following goals:
   written, which are still a fairly large and non-trivial set.
 
 - To make it as mechanical as possible to add a new shape function.
-
-## TODO
-
-We should develop a workflow with upstream to push our manually-authored shape
-functions to live and be tested there. We should also find a way to share with
-upstream the mapping between operators and their shape functions. We will be
-able to simplify this infrastructure quite a bit once that happens.
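The `Tensor` → `List[int]` signature derivation can be seen concretely for `aten::matmul : (Tensor, Tensor) -> (Tensor)`. A sketch of the resulting shape function, matching the `aten〇matmul` entry that appears in `shape_lib_gen.py` later in this commit (the import alias is an assumption; only the signature shape is the point here):

```python
from typing import List

import torch.jit._shape_functions as upstream_shape_functions

# Registry schema:   aten::matmul : (Tensor, Tensor) -> (Tensor)
# Derived signature: each `Tensor` becomes a `List[int]` of dimension sizes,
# which is what makes the function TorchScript-able and backend-independent.
def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.matmul(self, other)
```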

e2e_testing/torchscript/xfail_sets.py

Lines changed: 4 additions & 0 deletions
@@ -160,4 +160,8 @@
     "BaddbmmWithBetaModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
+    "NumpyTRank1Module_basic",
+    "NumpyTRank2Module_basic",
+    "NumpyTRankNStaticModule_basic",
+    "NumpyTRankNDynamicModule_basic",
 }
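The four new entries presumably correspond to e2e test modules exercising `Tensor.T` (the Python surface of `aten::numpy_T`) at rank 1, rank 2, and rank N; the actual module definitions live in the e2e test suite, not in this file. As a plain-PyTorch illustration of the behavior under test (names and harness omitted, not the real test code):

```python
import torch

# Tensor.T reverses the dimension order, which aten::numpy_T models.
x1 = torch.randn(4)        # rank 1: .T leaves the shape unchanged
x2 = torch.randn(3, 4)     # rank 2: ordinary transpose
print(x1.T.shape)          # torch.Size([4])
print(x2.T.shape)          # torch.Size([4, 3])
```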

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 45 additions & 0 deletions
@@ -4150,6 +4150,29 @@ def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
   }];
 }
 
+def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -5980,6 +6003,28 @@ def Torch_AtenTOp : Torch_Op<"aten.t", [
   }];
 }
 
+def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenNumpyTOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenFullOp : Torch_Op<"aten.full", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
     sizes.push_back(embeddingDim);
     int64_t resultRank = sizes.size();
 
-    auto indicesTy = weight.getType().cast<RankedTensorType>();
+    auto indicesTy = indices.getType().cast<RankedTensorType>();
     int64_t indicesRank = indicesTy.getRank();
     SmallVector<AffineExpr> indicesExprs;
     for (int i = 0; i < indicesRank; i++)
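The one-line fix reads the rank from the `indices` operand instead of from `weight`: the embedding result shape is the indices shape with the embedding dimension appended, so the indexing maps must be built from the indices rank. A quick eager-mode illustration of that shape rule (plain PyTorch, not the conversion code):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 7)              # (num_embeddings, embedding_dim)
indices = torch.randint(0, 10, (2, 3, 4))

out = F.embedding(indices, weight)
# Result shape is indices.shape + (embedding_dim,), i.e. rank(indices) + 1.
print(out.shape)                          # torch.Size([2, 3, 4, 7])
```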

lib/Conversion/TorchToStd/TorchToStd.cpp

Lines changed: 20 additions & 0 deletions
@@ -50,6 +50,24 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
 };
 } // namespace
 
+namespace {
+class ConvertAtenIsFloatingPointOp
+    : public OpConversionPattern<AtenIsFloatingPointOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenIsFloatingPointOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tensorType = op.self().getType().cast<BaseTensorType>();
+    bool result =
+        tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        op, BoolAttr::get(getContext(), result));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
 public:
@@ -301,6 +319,8 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenDimOp>();
     patterns.add<ConvertAtenDimOp>(typeConverter, context);
+    target.addIllegalOp<AtenIsFloatingPointOp>();
+    patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
     target.addIllegalOp<RuntimeAssertOp>();
     patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
     target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp>();
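The new pattern lowers `aten.is_floating_point` to an `arith.constant` computed from the operand's static dtype, since the answer depends only on the element type and never on runtime data. The eager-mode behavior it mirrors (illustrative only):

```python
import torch

# is_floating_point is a pure dtype query, so a lowering with a known dtype
# can fold it to a boolean constant instead of emitting any runtime code.
print(torch.zeros(2, 3, dtype=torch.float32).is_floating_point())  # True
print(torch.zeros(2, 3, dtype=torch.int64).is_floating_point())    # False
```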

lib/Conversion/Utils/Utils.cpp

Lines changed: 21 additions & 2 deletions
@@ -26,6 +26,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   // TODO: Remove this check but use a separate verification pass to verify the
   // invariants expected by later passes.
   auto isValidLinalgType = [](Type type) {
+    if (type.isa<NonValueTensorType>())
+      return false;
     auto tensor = type.dyn_cast<ValueTensorType>();
     return !tensor ||
            tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
@@ -242,15 +244,32 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     return false;
   };
 
-  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
-      dtype.isSignlessInteger(1)) {
+  if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
     // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
     mlir::emitError(loc)
         << "unsupported byte, char or bool type for convertScalarToDtype "
         << scalarType << "(scalar type) -> " << dtype << "(dtype)";
     return nullptr;
   }
 
+  // If the dtype is i1, i.e., a boolean type.
+  if (dtype.isSignlessInteger(1)) {
+    Type scalarType = scalar.getType();
+    Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
+    if (scalarType.isa<mlir::FloatType>()) {
+      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
+                                     cstZero);
+    } else if (scalarType.isa<mlir::IntegerType>()) {
+      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
+                                     cstZero);
+    } else {
+      mlir::emitError(loc)
+          << "unsupported scalar type for convertScalarToDtype " << scalarType
+          << "(scalar type) -> " << dtype << "(dtype)";
+      return nullptr;
+    }
+  }
+
   if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
     if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
       if (scalarFloat.getWidth() > dtypeFloat.getWidth())
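The new `i1` branch gives convert-to-bool the usual "compare against zero" semantics: float scalars use an unordered `!=` (`arith.cmpf une`, so NaN converts to true) and integer scalars an ordinary `!=` (`arith.cmpi ne`). A Python sketch of the intended semantics, purely for illustration:

```python
def scalar_to_bool(value) -> bool:
    """Illustrative only: mirrors the i1 branch of convertScalarToDtype."""
    if isinstance(value, float):
        # cmpf "une" is an unordered compare, so NaN != 0.0 is True;
        # Python's != on NaN behaves the same way.
        return value != 0.0
    if isinstance(value, int):
        # cmpi "ne" against a zero constant.
        return value != 0
    raise TypeError(f"unsupported scalar type: {type(value)}")

print(scalar_to_bool(0.0), scalar_to_bool(float("nan")), scalar_to_bool(3))
# False True True
```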

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 25 additions & 0 deletions
@@ -1957,6 +1957,29 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.numpy_T` op into `aten.permute` op.
+class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNumpyTOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    int64_t inputRank = getTensorRank(self);
+
+    SmallVector<Value> dimListElements;
+    for (int64_t i = inputRank - 1; i >= 0; i--)
+      dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i)));
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        dimListElements);
+    rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2102,6 +2125,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenBaddbmmOp>();
     patterns.add<DecomposeAtenFloorDivideOp>(context);
     target.addIllegalOp<AtenFloorDivideOp>();
+    patterns.add<DecomposeAtenNumpyTOp>(context);
+    target.addIllegalOp<AtenNumpyTOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
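The decomposition rewrites `aten.numpy_T` as a permute whose dimension list is simply the reversed dimension order. Its tensor-level effect, checked in plain PyTorch (illustration only; the pattern itself builds the reversed index list as torch-dialect constants):

```python
import torch

def numpy_t_via_permute(x: torch.Tensor) -> torch.Tensor:
    # Same dim list the decomposition constructs: inputRank-1, ..., 1, 0.
    return x.permute(*range(x.dim() - 1, -1, -1))

x = torch.randn(2, 3, 5)
print(numpy_t_via_permute(x).shape)                       # torch.Size([5, 3, 2])
assert torch.equal(numpy_t_via_permute(x), x.permute(2, 1, 0))
```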

lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ static bool isViewLikeOp(Operation *op) {
              Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
              AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
              AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-             TensorStaticInfoCastOp, AtenToDtypeLayoutOp>(op);
+             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp>(op);
 }
 
 namespace {
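`aten.numpy_T` joins the view-like set because, like `aten.t` and `aten.permute`, it returns a view that aliases its input's storage rather than producing a fresh tensor. A quick eager-mode illustration of that aliasing (not part of the pass; `Tensor.T` is the Python surface for `aten::numpy_T`):

```python
import torch

x = torch.randn(3, 4)
y = x.T                                 # a view: same storage, different strides
print(y.data_ptr() == x.data_ptr())     # True
y[0, 0] = 42.0
print(x[0, 0].item())                   # 42.0: writes through the view are visible
```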

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 1 addition & 1 deletion
@@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
           AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-          PrimAbsScalarOp>(op)) {
+          PrimAbsScalarOp, AtenNumpyTOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 13 additions & 0 deletions
@@ -5628,6 +5628,19 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %int0 = torch.constant.int 0
+    %true = torch.constant.bool true
+    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    torch.prim.Loop %1, %true, init() {
+    ^bb0(%arg1: !torch.int):
+      %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
+      torch.aten.insert.t %0, %int0, %2 : !torch.list<int>, !torch.int, !torch.int
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 6 additions & 0 deletions
@@ -529,6 +529,12 @@ def aten〇transpose〇int(self: List[int], dim0: int, dim1: int) -> List[int]:
 def aten〇t(self: List[int]) -> List[int]:
     return upstream_shape_functions.transpose(self, 0, 1)
 
+def aten〇numpy_T(self: List[int]) -> List[int]:
+    result_shape: List[int] = []
+    for i in self:
+        result_shape.insert(0, i)
+    return result_shape
+
 def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
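Since the shape function is plain Python over `List[int]`, it can be sanity-checked directly in a REPL before regenerating the shape library (the `〇`-mangled name is a valid Python identifier, so the calls below work as written):

```python
# aten〇numpy_T builds the reversed shape by repeatedly inserting at index 0.
print(aten〇numpy_T([2, 3, 5]))   # [5, 3, 2]
print(aten〇numpy_T([7]))         # [7]
print(aten〇numpy_T([]))          # []  (rank-0 tensors are unchanged)
```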

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -392,6 +392,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
+    emit("aten::is_floating_point : (Tensor) -> (bool)")
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
@@ -464,6 +465,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
     emit("aten::t : (Tensor) -> (Tensor)")
+    emit("aten::numpy_T : (Tensor) -> (Tensor)")
     emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
