@@ -1257,271 +1257,6 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
};
} // namespace
- // For layernorm, the mean and standard deviation are calculated separately over
- // the last certain number of dimensions, which have to be of the shape specified
- // by normalized_shape.
- //
- // The shapes of the different parts are as follows:
- // +-------------------+--------------------+
- // |  meanAndVarShape  |   normalizedShape  |
- // +-------------------+--------------------+
- // <------------+ inputShape +-------------->
- // There are the following steps:
- // Step 1. Check if all the arguments meet the requirements.
- // Step 2. Common parts to be used for getting mean and var.
- //         This includes the element count, affineMaps and iteratorTypes.
- // Step 3. Get mean.
- // Step 4. Get rSTD.
- // Step 5. Get layernorm.
- namespace {
- class ConvertAtenNativeLayerNormOp
-     : public OpConversionPattern<AtenNativeLayerNormOp> {
- public:
-   using OpConversionPattern::OpConversionPattern;
-   LogicalResult
-   matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
-                   ConversionPatternRewriter &rewriter) const override {
-     MLIRContext *context = op->getContext();
-     Location loc = op->getLoc();
-     Value input = adaptor.input();
-     Value weight = adaptor.weight();
-     Value bias = adaptor.bias();
-     Value eps = adaptor.eps();
-     Value normalizedShape = op.normalized_shape();
-
-     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-       return failure();
-
-     // TODO: Handle the None cases for the optional parameters:
-     // weight, bias.
-     if (failed(checkNotNone(rewriter, op, weight)) ||
-         failed(checkNotNone(rewriter, op, bias)))
-       return failure();
-
-     auto inputType = input.getType().cast<RankedTensorType>();
-     auto weightType = weight.getType().cast<RankedTensorType>();
-     auto biasType = bias.getType().cast<RankedTensorType>();
-     int64_t inputRank = inputType.getRank();
-     Type elemTy = inputType.getElementType();
-
-     // Step 1. Check if all the arguments meet the requirements.
-     SmallVector<Value> normalizedShapeSizesTorchInt;
-     if (!getListConstructElements(normalizedShape,
-                                   normalizedShapeSizesTorchInt)) {
-       return rewriter.notifyMatchFailure(op,
-                                          "Unimplemented normalized_shape not"
-                                          " constructed from ListConstruct");
-     }
-     SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
-         rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
-     int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
-     if (weightType.getRank() != normalizedShapeRank ||
-         biasType.getRank() != normalizedShapeRank ||
-         inputRank < normalizedShapeRank || normalizedShapeRank < 1)
-       return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
-                                              " normalized shape not compatible");
-
-     // Check all the dimensions match the normalized_shape
-     int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
-     for (auto en : enumerate((normalizedShapeSizesInt))) {
-       auto index = en.index();
-       auto inputDim =
-           getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
-       auto weightDim = getDimOp(rewriter, loc, weight, index);
-       auto biasDim = getDimOp(rewriter, loc, bias, index);
-
-       auto expectedSize = en.value();
-       checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
-       checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
-       checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
-     }
-
-     // Get iterator types for input shape.
-     SmallVector<StringRef> normalizedShapeIteratorTypes(
-         normalizedShapeRank, getReductionIteratorTypeName());
-     SmallVector<StringRef> meanAndVarIterationTypes(
-         meanAndVarShapeRank, getParallelIteratorTypeName());
-     SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
-     inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
-
-     // Step 2. Common parts to be used for getting mean and var.
-
-     // Get sizes and affineMaps needed for mean and var.
-     AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
-     SmallVector<AffineExpr> meanAndVarShapeExprs;
-     for (int i = 0; i < meanAndVarShapeRank; i++)
-       meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-     auto meanAndVarShapeAffineMap = AffineMap::get(
-         /*dimCount=*/inputRank,
-         /*symbolCount=*/0, meanAndVarShapeExprs, context);
-     SmallVector<Value> meanAndVarShapeSizes =
-         getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
-
-     // Get number of elements to be used for calculating mean and var.
-     Value elemCnts = normalizedShapeSizesInt[0];
-     for (int i = 1; i < normalizedShapeRank; i++) {
-       elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
-                                                 normalizedShapeSizesInt[i]);
-     }
-     Value elemCntsFloat =
-         rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
-
-     // Helper to calculate mean and var.
-     auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
-       SmallVector<AffineMap> indexingMaps(
-           2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-       Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
-           loc, meanAndVarShapeSizes, elemTy);
-       return rewriter
-           .create<linalg::GenericOp>(
-               loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
-               /*indexingMaps=*/indexingMaps,
-               /*iteratorTypes=*/meanAndVarIterationTypes,
-               [&](OpBuilder &b, Location loc, ValueRange args) {
-                 Value sumOrSqureSum = args[0];
-                 Value result =
-                     b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
-                 b.create<linalg::YieldOp>(loc, result);
-               })
-           .getResult(0);
-     };
-
-     // Step 3. Get mean.
-
-     // Get sum to be used for calculating mean.
-     SmallVector<AffineMap, 2> sumIndexingMaps = {
-         inputShapeAffineMap,      // input
-         meanAndVarShapeAffineMap, // output
-     };
-     auto initSumTensor =
-         createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-     Value sum = rewriter
-                     .create<linalg::GenericOp>(
-                         loc, initSumTensor.getType(), input, initSumTensor,
-                         /*indexingMaps=*/sumIndexingMaps,
-                         /*iteratorTypes=*/inputShapeIteratorTypes,
-                         [&](OpBuilder &b, Location loc, ValueRange args) {
-                           Value input = args[0], sum = args[1];
-                           Value result =
-                               rewriter.create<arith::AddFOp>(loc, sum, input);
-                           b.create<linalg::YieldOp>(loc, result);
-                         })
-                     .getResult(0);
-     Value mean = genMeanOrVarCalculation(sum);
-
-     // Step 4. Get rSTD.
-
-     // Calculate squareSum for the layer.
-     SmallVector<AffineMap> squareSumIndexingMaps{
-         inputShapeAffineMap,
-         meanAndVarShapeAffineMap,
-         meanAndVarShapeAffineMap,
-     };
-     auto initSquareSumTensor =
-         createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-     Value squareSum =
-         rewriter
-             .create<linalg::GenericOp>(
-                 loc, initSquareSumTensor.getType(), ValueRange{input, mean},
-                 initSquareSumTensor,
-                 /*indexingMaps=*/squareSumIndexingMaps,
-                 /*iteratorTypes=*/inputShapeIteratorTypes,
-                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                   Value input = args[0], mean = args[1], squareSum = args[2];
-                   Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
-                   Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
-                   Value result =
-                       rewriter.create<arith::AddFOp>(loc, squareSum, square);
-                   b.create<linalg::YieldOp>(loc, result);
-                 })
-             .getResult(0);
-     Value var = genMeanOrVarCalculation(squareSum);
-     Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
-         loc, meanAndVarShapeSizes, elemTy);
-     SmallVector<AffineMap> rSTDIndexingMap(
-         2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-
-     Value rSTD = rewriter
-                      .create<linalg::GenericOp>(
-                          loc, rSTDTensor.getType(), var, rSTDTensor,
-                          rSTDIndexingMap, meanAndVarIterationTypes,
-                          [&](OpBuilder &b, Location loc, ValueRange args) {
-                            Value result =
-                                calculateRSTD(b, loc, elemTy, eps, args[0]);
-                            b.create<linalg::YieldOp>(loc, result);
-                          })
-                      .getResult(0);
-
-     // Step 5. Get layernorm.
-
-     // Get affineMap for normalized shape.
-     SmallVector<AffineExpr> normalizedShapeExprs;
-     for (int i = meanAndVarShapeRank; i < inputRank; i++)
-       normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-     auto normalizedShapeAffineMap = AffineMap::get(
-         /*dimCount=*/inputRank,
-         /*symbolCount=*/0, normalizedShapeExprs, context);
-     auto inputSizes = getTensorSizes(rewriter, loc, input);
-     Value initLayerNormTensor =
-         rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
-     SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
-     indexingMaps.resize(3, meanAndVarShapeAffineMap);
-     indexingMaps.resize(5, normalizedShapeAffineMap);
-     indexingMaps.push_back(inputShapeAffineMap);
-     SmallVector<StringRef> layerNormIterationTypes(
-         inputRank, getParallelIteratorTypeName());
-     Value layerNorm =
-         rewriter
-             .create<linalg::GenericOp>(
-                 loc, initLayerNormTensor.getType(),
-                 ValueRange{input, mean, rSTD, weight, bias},
-                 initLayerNormTensor,
-                 /*indexingMaps=*/indexingMaps,
-                 /*iteratorTypes=*/layerNormIterationTypes,
-                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                   Value input = args[0], mean = args[1], rSTD = args[2],
-                         weight = args[3], bias = args[4];
-                   Value result =
-                       createLinalgPayloadCalculationForNormOpsWithRSTD(
-                           b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
-                   b.create<linalg::YieldOp>(loc, result);
-                 })
-             .getResult(0);
-     SmallVector<int64_t> expandShape(inputRank, 1);
-     for (int i = 0; i < meanAndVarShapeRank; i++) {
-       // `mean` and `rstd` are not yet cast, so they will have dynamic shape.
-       // Hence, to match them, assign -1 to each dimension corresponding to
-       // `mean` or `rstd`.
-       expandShape[i] = -1;
-     }
-     auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
-     SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
-     for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
-       reassociation[i].push_back(i);
-       if (i == meanAndVarShapeRank - 1) {
-         for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
-           reassociation[i].push_back(i + j + 1);
-       }
-     }
-     Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
-         loc, expandShapeType, mean, reassociation);
-     Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
-         loc, expandShapeType, rSTD, reassociation);
-     Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
-     Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-     Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
-     Value layerNorm_ =
-         rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-     Value mean_ =
-         rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
-     Value var_ =
-         rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
-     rewriter.replaceOp(op, {layerNorm_, mean_, var_});
-     return success();
-   }
- };
- } // namespace
-
namespace {
class ConvertAtenNllLossBackwardOp
    : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1728,8 +1463,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
  patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
  target.addIllegalOp<AtenBatchNormOp>();
  patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
-   target.addIllegalOp<AtenNativeLayerNormOp>();
-   patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossBackwardOp>();
  patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
  patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
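For reference, the removed ConvertAtenNativeLayerNormOp pattern lowered aten.native_layer_norm into a chain of linalg.generic ops: a sum reduction for the mean, a squared-difference reduction for the variance, a pointwise op computing rSTD, and a final pointwise op applying weight and bias. A sketch of the math, assuming the calculateRSTD and createLinalgPayloadCalculationForNormOpsWithRSTD helpers (whose bodies are not shown in this diff) implement the standard layer-norm formulas, with N the number of elements covered by normalized_shape:

\[
\mu = \frac{1}{N}\sum_{i} x_i, \qquad
\sigma^2 = \frac{1}{N}\sum_{i} (x_i - \mu)^2, \qquad
r = \frac{1}{\sqrt{\sigma^2 + \epsilon}}, \qquad
y_i = (x_i - \mu)\, r \, w_i + b_i,
\]

where the sums run over the dimensions covered by normalized_shape, and \mu and r are the extra mean and rstd results the op returns after being expanded and cast to the converted result types.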