Commit 0b23af2

Author: Tanyo Kwok

[MHLO] support non-constant torch scalar in BasicOps (#1134)

See RFC #999

Co-authored-by: Bairen Yi [email protected]
Co-authored-by: Jiawei Wu [email protected]
Co-authored-by: Tianyou Guo [email protected]
Co-authored-by: Xu Yan [email protected]
Co-authored-by: Ziheng Jiang [email protected]
1 parent 82af44d commit 0b23af2

File tree

5 files changed: +362 −350 lines changed

lib/Conversion/TorchToMhlo/BasicOp.cpp

Lines changed: 13 additions & 29 deletions
@@ -159,23 +159,15 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
     }
 
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
 
     if (!skipMultiplyAlpha(op.alpha())) {
-      Value alpha;
-      if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
-                                              op.alpha(), alpha, outElemTy, {},
-                                              /*checkForUnity=*/false))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "alpha in conversion to MHLO operation");
-      }
+      Value alpha =
+          mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
       DenseIntElementsAttr bcastDimensions;
       rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
                                                   bcastDimensions);
@@ -216,13 +208,13 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-    }
 
+    Value lhsTensor = lhs;
+    if (std::is_same<AtenOpT, AtenSquareOp>()) {
+      rhs = lhs;
+    } else if (!rhsType) {
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
+    }
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
@@ -263,11 +255,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsTy) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               lhsElemTy, {}))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-      }
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
     }
 
     // TODO: what is the PyTorch default type promotion?
@@ -569,12 +557,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
                         .cast<RankedTensorType>();
   auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
-  Value mhloTensor;
-  if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
-                                           outputElemType, outputShape,
-                                           false))) {
-    return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
-  }
+  Value mhloTensor =
+      mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
   rewriter.replaceOp(op, mhloTensor);
   return success();
 }
@@ -1020,4 +1004,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
 #undef INSERT_ATENOP_PATTERN
-}
+}
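
Note: the common thread in these hunks is that each pattern now reads the scalar operand from the conversion adaptor (adaptor.other(), adaptor.alpha(), adaptor.a()), where the Torch-to-builtin type converter has already materialized it as a plain i64/f64 SSA value, instead of pattern-matching op.other() / op.alpha() against a compile-time Torch constant. That is what lifts the old "only scalar constants are supported" restriction. A minimal sketch of the resulting pattern shape, assuming the usual torch-mlir OpConversionPattern boilerplate (the AtenAddScalarOp pattern and surrounding names here are illustrative, not the literal code of this commit):

// Illustrative sketch only -- names other than mhlo::scalarToMhloTensor and
// mhlo::promoteType are assumptions about the surrounding torch-mlir code,
// not part of this commit.
class ConvertAtenAddScalarOp : public OpConversionPattern<AtenAddScalarOp> {
public:
  using OpConversionPattern<AtenAddScalarOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenAddScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto outType = getTypeConverter()
                       ->convertType(op.getType())
                       .cast<RankedTensorType>();
    Value lhs = adaptor.self(); // already a builtin tensor value
    // adaptor.other() is an i64/f64 SSA value; it need not be a constant,
    // so it can be turned into a 0-d tensor at runtime.
    Value rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(),
                                         outType.getElementType());
    lhs = mhlo::promoteType(rewriter, lhs, outType);
    rhs = mhlo::promoteType(rewriter, rhs, outType);
    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outType, lhs, rhs,
                                                      bcastDimensions);
    return success();
  }
};

Leaving bcastDimensions null, as the patterns above do, lets the chlo broadcast op apply default broadcasting, so the rank-0 scalar tensor broadcasts against the lhs at runtime.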

lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp

Lines changed: 10 additions & 88 deletions
@@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
-// TODO: Support for variable scalar.
-LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Value torchScalarValue,
-                                      Value &mhloTensor, Type dtype,
-                                      llvm::ArrayRef<int64_t> dshape,
-                                      bool doBroadcast) {
-  // Retrieve a const float or int value but create the out Tensor with dtype.
-  double doubleValue;
-  auto isFloat =
-      matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
-
-  int64_t intValue;
-  auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
-
-  if (!isFloat && !isInt)
-    return op->emitError("Unable to extract the scalar constant");
-
-  if (dtype.isa<mlir::FloatType>()) {
-    if (doBroadcast) {
-      mhloTensor = getSplatConstTensor<float>(
-          rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
-    } else {
-      mhloTensor = mhlo::getConstTensor<float>(
-                       rewriter, op, (isFloat ? doubleValue : intValue), dshape)
-                       .getValue();
-    }
-  } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
-    auto w = intType.getWidth();
-    if (w != 32 && w != 64)
-      return op->emitError("Unsupported integer type") << intType;
-
-    if (w == 32) {
-      if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
-      }
-      int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
-                          : static_cast<int32_t>(intValue);
-      if (doBroadcast) {
-        mhloTensor =
-            getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
-      } else {
-        mhloTensor =
-            mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
-      }
-    } else if (w == 64) {
-      if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
-      }
-      int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
-      if (doBroadcast) {
-        mhloTensor =
-            getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
-      } else {
-        mhloTensor =
-            mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
-      }
-    }
-  } else
-    return op->emitError("Usupported element type");
-
-  return success();
-}
-
-LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
-                                     Operation *op, Value alphaScalar,
-                                     Value &alphaTensor, Type dtype,
-                                     llvm::ArrayRef<int64_t> dshape,
-                                     bool checkForUnity) {
-  if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
-                                        dtype, dshape)))
-    return success();
-
-  // `alpha` has not been specified.
-  int64_t alphaValue;
-  if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
-    return op->emitError("Currently only scalar constants are supported for "
-                         "alpha in MHLO operation");
-  // When no alpha has been specified, this must be 1.
-  if (checkForUnity && alphaValue != 1)
-    return op->emitError("Unsupported integer value for alpha");
-
-  alphaTensor =
-      mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);
-
-  return success();
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype) {
+  auto tensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), ArrayRef<Value>{scalarValue});
+  auto dtype_tensor =
+      rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
+  return rewriter.create<mhlo::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
+      dtype_tensor);
 }
 
 Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
@@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
       .getResult();
 }
 } // namespace mhlo
-} // namespace mlir
+} // namespace mlir
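
Note: the replacement helper builds the scalar tensor at run time instead of folding a constant. It wraps the scalar SSA value with tensor.from_elements, casts the element type with mhlo.convert, and reshapes the tensor<1xT> down to a rank-0 tensor<T>. A usage sketch under assumptions (the surrounding adaptor/rewriter context is what any conversion pattern would provide, and the IR in the comment is only an approximation of the printed form):

// Inside a conversion pattern, where the Torch type converter has already
// turned the !torch.int / !torch.float operand into a builtin i64/f64 value:
Value scalar = adaptor.other();          // not necessarily a constant
Type f32Ty = rewriter.getF32Type();

// Roughly produces:
//   %0 = tensor.from_elements %scalar : tensor<1xi64>
//   %1 = mhlo.convert(%0) : (tensor<1xi64>) -> tensor<1xf32>
//   %2 = "mhlo.reshape"(%1) : (tensor<1xf32>) -> tensor<f32>
Value rhs = mhlo::scalarToMhloTensor(rewriter, op, scalar, f32Ty);

Compared with the deleted torchScalarToMhloTensor, this drops the constant matching (m_TorchConstantInt / m_TorchConstantFloat), the 32/64-bit range checks, and the doBroadcast splat path; broadcasting is left to the chlo broadcast ops at the call sites.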

lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h

Lines changed: 2 additions & 11 deletions
@@ -47,17 +47,8 @@ template <typename T>
 Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                           T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
 
-LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Value torchScalarValue,
-                                      Value &mhloTensor, Type dtype,
-                                      llvm::ArrayRef<int64_t> dshape,
-                                      bool doBroadcast = true);
-
-LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
-                                     Operation *op, Value alphaScalar,
-                                     Value &alphaTensor, Type dtype,
-                                     llvm::ArrayRef<int64_t> dshape,
-                                     bool checkForUnity);
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype);
 
 Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
 
test/Conversion/TorchToMhlo/basic.mlir

Lines changed: 10 additions & 6 deletions
@@ -41,11 +41,15 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
 
 // -----
 
-// CHECK-LABEL:  func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
-// CHECK:         %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
-// CHECK:         return %[[VAL_1]] : !torch.vtensor<[],si64>
+// CHECK-LABEL:  func.func @torch.prim.NumToTensor.Scalar$basic(
+// CHECK-SAME:         ) -> !torch.vtensor<[],si64> {
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
+// CHECK:         %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
+// CHECK:         %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
+// CHECK:         %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK:         return %[[T4]] : !torch.vtensor<[],si64>
 func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
   %int1 = torch.constant.int 1
   %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
@@ -251,4 +255,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
   %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
   %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
   return %result0 : !torch.vtensor<[3,7,4,5],f32>
-}
+}
