@@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
174
174
return const_op.getResult ();
175
175
}
176
176
177
- // TODO: Support for variable scalar.
178
- LogicalResult torchScalarToMhloTensor (ConversionPatternRewriter &rewriter,
179
- Operation *op, Value torchScalarValue,
180
- Value &mhloTensor, Type dtype,
181
- llvm::ArrayRef<int64_t > dshape,
182
- bool doBroadcast) {
183
- // Retrieve a const float or int value but create the out Tensor with dtype.
184
- double doubleValue;
185
- auto isFloat =
186
- matchPattern (torchScalarValue, m_TorchConstantFloat (&doubleValue));
187
-
188
- int64_t intValue;
189
- auto isInt = matchPattern (torchScalarValue, m_TorchConstantInt (&intValue));
190
-
191
- if (!isFloat && !isInt)
192
- return op->emitError (" Unable to extract the scalar constant" );
193
-
194
- if (dtype.isa <mlir::FloatType>()) {
195
- if (doBroadcast) {
196
- mhloTensor = getSplatConstTensor<float >(
197
- rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
198
- } else {
199
- mhloTensor = mhlo::getConstTensor<float >(
200
- rewriter, op, (isFloat ? doubleValue : intValue), dshape)
201
- .getValue ();
202
- }
203
- } else if (auto intType = dtype.dyn_cast <mlir::IntegerType>()) {
204
- auto w = intType.getWidth ();
205
- if (w != 32 && w != 64 )
206
- return op->emitError (" Unsupported integer type" ) << intType;
207
-
208
- if (w == 32 ) {
209
- if (!isInValidRange<int32_t >(isFloat, doubleValue, isInt, intValue)) {
210
- return op->emitError (" Supplied value of scalar constant exceeds limits "
211
- " of destination type" );
212
- }
213
- int32_t d = isFloat ? static_cast <int32_t >(doubleValue)
214
- : static_cast <int32_t >(intValue);
215
- if (doBroadcast) {
216
- mhloTensor =
217
- getSplatConstTensor<int32_t >(rewriter, op, d, dtype, dshape);
218
- } else {
219
- mhloTensor =
220
- mhlo::getConstTensor<int32_t >(rewriter, op, {d}, dshape).getValue ();
221
- }
222
- } else if (w == 64 ) {
223
- if (!isInValidRange<int64_t >(isFloat, doubleValue, isInt, intValue)) {
224
- return op->emitError (" Supplied value of scalar constant exceeds limits "
225
- " of destination type" );
226
- }
227
- int64_t d = (isFloat ? static_cast <int64_t >(doubleValue) : intValue);
228
- if (doBroadcast) {
229
- mhloTensor =
230
- getSplatConstTensor<int64_t >(rewriter, op, d, dtype, dshape);
231
- } else {
232
- mhloTensor =
233
- mhlo::getConstTensor<int64_t >(rewriter, op, {d}, dshape).getValue ();
234
- }
235
- }
236
- } else
237
- return op->emitError (" Usupported element type" );
238
-
239
- return success ();
240
- }
241
-
242
- LogicalResult torchAlphaToMhloTensor (ConversionPatternRewriter &rewriter,
243
- Operation *op, Value alphaScalar,
244
- Value &alphaTensor, Type dtype,
245
- llvm::ArrayRef<int64_t > dshape,
246
- bool checkForUnity) {
247
- if (succeeded (torchScalarToMhloTensor (rewriter, op, alphaScalar, alphaTensor,
248
- dtype, dshape)))
249
- return success ();
250
-
251
- // `alpha` has not been specified.
252
- int64_t alphaValue;
253
- if (!matchPattern (alphaScalar, m_TorchConstantInt (&alphaValue)))
254
- return op->emitError (" Currently only scalar constants are supported for "
255
- " alpha in MHLO operation" );
256
- // When no alpha has been specified, this must be 1.
257
- if (checkForUnity && alphaValue != 1 )
258
- return op->emitError (" Unsupported integer value for alpha" );
259
-
260
- alphaTensor =
261
- mlir::mhlo::getMhloConstTensorSingleF32 (rewriter, op, alphaValue);
262
-
263
- return success ();
177
+ Value scalarToMhloTensor (ConversionPatternRewriter &rewriter, Operation *op,
178
+ Value scalarValue, Type dtype) {
179
+ auto tensor = rewriter.create <tensor::FromElementsOp>(
180
+ op->getLoc (), ArrayRef<Value>{scalarValue});
181
+ auto dtype_tensor =
182
+ rewriter.create <mhlo::ConvertOp>(op->getLoc (), tensor, dtype);
183
+ return rewriter.create <mhlo::ReshapeOp>(
184
+ op->getLoc (), RankedTensorType::get (mlir::ArrayRef<int64_t >{}, dtype),
185
+ dtype_tensor);
264
186
}
265
187
266
188
Value promoteType (PatternRewriter &rewriter, Value input, TensorType outType) {
@@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
439
361
.getResult ();
440
362
}
441
363
} // namespace mhlo
442
- } // namespace mlir
364
+ } // namespace mlir
0 commit comments