
Commit 37f57a9

Author: Tanyo Kwok
Delete ConvertAtenNativeLayerNormOp from TorchToLinalg (#1336)
ConvertAtenNativeLayerNormOp is deleted because we already have a decomposition for this op; see #1332.
1 parent e6528f7 · commit 37f57a9
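
For context, the deleted pattern lowered torch.aten.native_layer_norm directly to linalg; after #1332 the op is instead decomposed into simpler ops before conversion. Over each slice of the trailing normalized_shape dimensions (N elements, the elemCnts in the deleted code), native_layer_norm returns the normalized output together with the per-slice mean and rstd. Informally, in LaTeX:

\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\sigma^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu)^2, \qquad
\mathrm{rstd} = \frac{1}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
y_i = (x_i - \mu)\,\mathrm{rstd}\; w_i + b_i

These correspond to Step 3 (mean), Step 4 (rSTD), and Step 5 (layernorm) in the deleted conversion shown in the diff below.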

File tree: 1 file changed (+0, −267 lines)


lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 0 additions & 267 deletions
@@ -1257,271 +1257,6 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 };
 } // namespace
 
-// For layernorm, the mean and standard-deviation are calculated separately over
-// the last certain number dimensions which have to be of the shape specified by
-// normalized_shape.
-//
-// The shapes of different parts are as the following:
-// +-------------------+--------------------+
-// | meanAndVarShape   | normalizedShape    |
-// +-------------------+---------------------
-// <------------+ inputShape +-------------->
-// There are the following steps:
-// Step 1. Check if all the arguments meet the requirements.
-// Step 2. Common parts to be used for getting mean and var.
-//         This includes elements count, affineMap and iteratorTypes.
-// Step 3. Get mean.
-// Step 4. Get rSTD.
-// Step 5. Get layernorm.
-namespace {
-class ConvertAtenNativeLayerNormOp
-    : public OpConversionPattern<AtenNativeLayerNormOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
-    Value input = adaptor.input();
-    Value weight = adaptor.weight();
-    Value bias = adaptor.bias();
-    Value eps = adaptor.eps();
-    Value normalizedShape = op.normalized_shape();
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    // TODO: Handle the None cases for the optional parameters:
-    // weight, bias.
-    if (failed(checkNotNone(rewriter, op, weight)) ||
-        failed(checkNotNone(rewriter, op, bias)))
-      return failure();
-
-    auto inputType = input.getType().cast<RankedTensorType>();
-    auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
-    int64_t inputRank = inputType.getRank();
-    Type elemTy = inputType.getElementType();
-
-    // Step 1. Check if all the arguments meet the requirements.
-    SmallVector<Value> normalizedShapeSizesTorchInt;
-    if (!getListConstructElements(normalizedShape,
-                                  normalizedShapeSizesTorchInt)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "Unimplemented normalized_shape not"
-                                         "constructed from ListConstruct");
-    }
-    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
-        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
-    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
-    if (weightType.getRank() != normalizedShapeRank ||
-        biasType.getRank() != normalizedShapeRank ||
-        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
-      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
-                                             "normalized shape not compatible");
-
-    // Check all the dimensions match the normalized_shape
-    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
-    for (auto en : enumerate((normalizedShapeSizesInt))) {
-      auto index = en.index();
-      auto inputDim =
-          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
-      auto weightDim = getDimOp(rewriter, loc, weight, index);
-      auto biasDim = getDimOp(rewriter, loc, bias, index);
-
-      auto expectedSize = en.value();
-      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
-      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
-      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
-    }
-
-    // Get iterator types for input shape.
-    SmallVector<StringRef> normalizedShapeIteratorTypes(
-        normalizedShapeRank, getReductionIteratorTypeName());
-    SmallVector<StringRef> meanAndVarIterationTypes(
-        meanAndVarShapeRank, getParallelIteratorTypeName());
-    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
-    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
-
-    // Step 2. Common parts to be used for getting mean and var.
-
-    // Get sizes and affineMaps needed for mean and var.
-    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
-    SmallVector<AffineExpr> meanAndVarShapeExprs;
-    for (int i = 0; i < meanAndVarShapeRank; i++)
-      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-    auto meanAndVarShapeAffineMap = AffineMap::get(
-        /*dimCount=*/inputRank,
-        /*symbolCount=*/0, meanAndVarShapeExprs, context);
-    SmallVector<Value> meanAndVarShapeSizes =
-        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
-
-    // Get number of elements to be used for calculating mean and var.
-    Value elemCnts = normalizedShapeSizesInt[0];
-    for (int i = 1; i < normalizedShapeRank; i++) {
-      elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
-                                                normalizedShapeSizesInt[i]);
-    }
-    Value elemCntsFloat =
-        rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
-
-    // Helper to calculate mean and var.
-    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
-      SmallVector<AffineMap> indexingMaps(
-          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, meanAndVarShapeSizes, elemTy);
-      return rewriter
-          .create<linalg::GenericOp>(
-              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
-              /*indexingMaps=*/indexingMaps,
-              /*iteratorTypes=*/meanAndVarIterationTypes,
-              [&](OpBuilder &b, Location loc, ValueRange args) {
-                Value sumOrSqureSum = args[0];
-                Value result =
-                    b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
-                b.create<linalg::YieldOp>(loc, result);
-              })
-          .getResult(0);
-    };
-
-    // Step 3. Get mean.
-
-    // Get sum to be used for calculating mean.
-    SmallVector<AffineMap, 2> sumIndexingMaps = {
-        inputShapeAffineMap,      // input
-        meanAndVarShapeAffineMap, // output
-    };
-    auto initSumTensor =
-        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-    Value sum = rewriter
-                    .create<linalg::GenericOp>(
-                        loc, initSumTensor.getType(), input, initSumTensor,
-                        /*indexingMaps=*/sumIndexingMaps,
-                        /*iteratorTypes=*/inputShapeIteratorTypes,
-                        [&](OpBuilder &b, Location loc, ValueRange args) {
-                          Value input = args[0], sum = args[1];
-                          Value result =
-                              rewriter.create<arith::AddFOp>(loc, sum, input);
-                          b.create<linalg::YieldOp>(loc, result);
-                        })
-                    .getResult(0);
-    Value mean = genMeanOrVarCalculation(sum);
-
-    // Step 4. Get rSTD.
-
-    // Calculate squareSum for the layer.
-    SmallVector<AffineMap> squareSumIndexingMaps{
-        inputShapeAffineMap,
-        meanAndVarShapeAffineMap,
-        meanAndVarShapeAffineMap,
-    };
-    auto initSquareSumTensor =
-        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-    Value squareSum =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
-                initSquareSumTensor,
-                /*indexingMaps=*/squareSumIndexingMaps,
-                /*iteratorTypes=*/inputShapeIteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], squareSum = args[2];
-                  Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
-                  Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
-                  Value result =
-                      rewriter.create<arith::AddFOp>(loc, squareSum, square);
-                  b.create<linalg::YieldOp>(loc, result);
-                })
-            .getResult(0);
-    Value var = genMeanOrVarCalculation(squareSum);
-    Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, meanAndVarShapeSizes, elemTy);
-    SmallVector<AffineMap> rSTDIndexingMap(
-        2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-
-    Value rSTD = rewriter
-                     .create<linalg::GenericOp>(
-                         loc, rSTDTensor.getType(), var, rSTDTensor,
-                         rSTDIndexingMap, meanAndVarIterationTypes,
-                         [&](OpBuilder &b, Location loc, ValueRange args) {
-                           Value result =
-                               calculateRSTD(b, loc, elemTy, eps, args[0]);
-                           b.create<linalg::YieldOp>(loc, result);
-                         })
-                     .getResult(0);
-
-    // Step 5. Get layernorm.
-
-    // Get affineMap for normalized shape.
-    SmallVector<AffineExpr> normalizedShapeExprs;
-    for (int i = meanAndVarShapeRank; i < inputRank; i++)
-      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-    auto normalizedShapeAffineMap = AffineMap::get(
-        /*dimCount=*/inputRank,
-        /*symbolCount=*/0, normalizedShapeExprs, context);
-    auto inputSizes = getTensorSizes(rewriter, loc, input);
-    Value initLayerNormTensor =
-        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
-    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
-    indexingMaps.resize(3, meanAndVarShapeAffineMap);
-    indexingMaps.resize(5, normalizedShapeAffineMap);
-    indexingMaps.push_back(inputShapeAffineMap);
-    SmallVector<StringRef> layerNormIterationTypes(
-        inputRank, getParallelIteratorTypeName());
-    Value layerNorm =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initLayerNormTensor.getType(),
-                ValueRange{input, mean, rSTD, weight, bias},
-                initLayerNormTensor,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/layerNormIterationTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], rSTD = args[2],
-                        weight = args[3], bias = args[4];
-                  Value result =
-                      createLinalgPayloadCalculationForNormOpsWithRSTD(
-                          b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
-                  b.create<linalg::YieldOp>(loc, result);
-                })
-            .getResult(0);
-    SmallVector<int64_t> expandShape(inputRank, 1);
-    for (int i = 0; i < meanAndVarShapeRank; i++) {
-      // `mean` and `rstd` are not yet casted, so they will be having dynamic
-      // shape. Hence to match them, for each dimension corresponding to `mean`
-      // or `rstd` assign -1.
-      expandShape[i] = -1;
-    }
-    auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
-    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
-    for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
-      reassociation[i].push_back(i);
-      if (i == meanAndVarShapeRank - 1) {
-        for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
-          reassociation[i].push_back(i + j + 1);
-      }
-    }
-    Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, mean, reassociation);
-    Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, rSTD, reassociation);
-    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
-    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-    Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
-    Value layerNorm_ =
-        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-    Value mean_ =
-        rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
-    Value var_ =
-        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
-    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1728,8 +1463,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
   target.addIllegalOp<AtenBatchNormOp>();
   patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
-  target.addIllegalOp<AtenNativeLayerNormOp>();
-  patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossBackwardOp>();
   patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
   patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
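
As a reading aid (not code from the repo), the following is a minimal plain-C++ sketch of the arithmetic the deleted pattern expressed through linalg.generic ops, with the tensor viewed as rows = product of the leading (meanAndVarShape) dims and cols = product of the normalized_shape dims. The names nativeLayerNorm and LayerNormResult are illustrative assumptions, not torch-mlir APIs.

// Plain-C++ sketch of native_layer_norm's math, mirroring the deleted
// pattern's Step 3 (mean), Step 4 (rstd), and Step 5 (normalize + affine).
#include <cmath>
#include <cstdio>
#include <vector>

struct LayerNormResult {
  std::vector<float> output; // rows * cols, normalized and affine-transformed
  std::vector<float> mean;   // rows, per-slice mean
  std::vector<float> rstd;   // rows, per-slice 1 / sqrt(var + eps)
};

LayerNormResult nativeLayerNorm(const std::vector<float> &input,
                                const std::vector<float> &weight,
                                const std::vector<float> &bias,
                                size_t rows, size_t cols, float eps) {
  LayerNormResult r{std::vector<float>(rows * cols), std::vector<float>(rows),
                    std::vector<float>(rows)};
  for (size_t i = 0; i < rows; ++i) {
    const float *x = &input[i * cols];
    // Step 3 in the deleted code: mean = sum / elemCnts.
    float sum = 0.f;
    for (size_t j = 0; j < cols; ++j)
      sum += x[j];
    float mean = sum / cols;
    // Step 4: rstd = 1 / sqrt(squareSum / elemCnts + eps).
    float squareSum = 0.f;
    for (size_t j = 0; j < cols; ++j)
      squareSum += (x[j] - mean) * (x[j] - mean);
    float rstd = 1.f / std::sqrt(squareSum / cols + eps);
    // Step 5: y = (x - mean) * rstd * weight + bias.
    for (size_t j = 0; j < cols; ++j)
      r.output[i * cols + j] = (x[j] - mean) * rstd * weight[j] + bias[j];
    r.mean[i] = mean;
    r.rstd[i] = rstd;
  }
  return r;
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<float> w{1.f, 1.f, 1.f}, b{0.f, 0.f, 0.f};
  LayerNormResult r = nativeLayerNorm(x, w, b, /*rows=*/2, /*cols=*/3, 1e-5f);
  for (float v : r.output)
    std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}

Compiling and running the sketch prints the normalized values for two rows of three elements. The real pattern produced the same three results (output, mean, rstd) as tensors, expanding mean and rstd back to the input rank via tensor.expand_shape before replacing the op.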
