
Commit 819f293
Author: Prashant Kumar
Parent: ddd45d6

Decompose aten.silu op

Decomposition of the aten.silu op is added as silu(x) = x * sigmoid(x).
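As a quick sanity check of the identity this commit relies on (not part of the commit; torch.nn.functional.silu is PyTorch's reference implementation):

import torch

# silu(x) = x * sigmoid(x): compare the decomposed form against the built-in.
x = torch.randn(128, 128)
decomposed = x * torch.sigmoid(x)
assert torch.allclose(torch.nn.functional.silu(x), decomposed)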

File tree: 7 files changed, +79 −1 lines


e2e_testing/torchscript/basic.py

Lines changed: 18 additions & 0 deletions
@@ -1345,3 +1345,21 @@ def forward(self, x):
 def HardswishRandomModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(128, 128, low=-10, high=10))
 
+# ==============================================================================
+
+class SiluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.silu(x)
+
+
+@register_test_case(module_factory=lambda: SiluModule())
+def SiluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 128, low=-10, high=10))

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -129,4 +129,5 @@
     "NewOnesModuleFloat2D_basic",
     "NewOnesModuleFloat3D_basic",
     "NewOnesModuleFalsePinMemory_basic",
+    "SiluModule_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 28 additions & 0 deletions
@@ -214,6 +214,34 @@ def Torch_AtenHardswish_Op : Torch_Op<"aten.hardswish_", [
   let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
 }
 
+def Torch_AtenSiluOp : Torch_Op<"aten.silu", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::silu : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
+def Torch_AtenSilu_Op : Torch_Op<"aten.silu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::silu_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
 def Torch_AtenSinOp : Torch_Op<"aten.sin", [
     AllowsTypeRefinement,
     HasValueSemantics

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 19 additions & 0 deletions
@@ -740,6 +740,23 @@ class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
 };
 } // namespace
 
+// Silu(x) = sigmoid(x) * x
+namespace {
+class DecomposeAtenSiluOp : public OpRewritePattern<AtenSiluOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSiluOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.self();
+    Value sigmoid =
+        rewriter.create<AtenSigmoidOp>(op.getLoc(), op.getType(), self);
+    rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), sigmoid,
+                                                 self);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.var into: sum(square(x - mean))/(numTensorElements-1)
 // for unbiased and mean(square(x - mean)) for biased case.
 namespace {
@@ -1318,6 +1335,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenHardsigmoidOp>();
     patterns.add<DecomposeAtenHardswishOp>(context);
     target.addIllegalOp<AtenHardswishOp>();
+    patterns.add<DecomposeAtenSiluOp>(context);
+    target.addIllegalOp<AtenSiluOp>();
     patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
         context);
     target.addIllegalOp<AtenNewZerosOp>();

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
             AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
             AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
             PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
-            AtenHardsigmoidOp, AtenHardswishOp>(op)) {
+            AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -452,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::sigmoid : (Tensor) -> (Tensor)",
         "aten::hardsigmoid : (Tensor) -> (Tensor)",
         "aten::hardswish : (Tensor) -> (Tensor)",
+        "aten::silu : (Tensor) -> (Tensor)",
         "aten::sin : (Tensor) -> (Tensor)",
         "aten::exp : (Tensor) -> (Tensor)",
         "aten::cos : (Tensor) -> (Tensor)",

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 11 additions & 0 deletions
@@ -509,3 +509,14 @@ func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[
   %1 = torch.aten.new_ones %arg0, %0, %none, %none, %none, %none : !torch.vtensor<[?,?],si64>, !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
   return %1 : !torch.vtensor<[3,4],si64>
 }
+
+// -----
+// CHECK-LABEL:   func @torch.aten.silu(
+// CHECK-SAME:    %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
+// CHECK:           %[[SIGMOID:.*]] = torch.aten.sigmoid %[[INP]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor
+// CHECK:           %[[MUL:.*]] = torch.aten.mul.Tensor %[[SIGMOID]], %[[INP]] : !torch.vtensor, !torch.vtensor<[?,?],f32> -> !torch.vtensor
+// CHECK:           return %[[MUL]] : !torch.vtensor
+func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.vtensor {
+  %0 = torch.aten.silu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
+  return %0 : !torch.vtensor
+}
