Skip to content

Commit e7bcf3e

Browse files
Lancernlanza
authored andcommitted
[CIR][NFC] Fix bug during fp16 unary op CIRGen (#706)
This PR fixes a bug during the CIRGen of fp16 unary operations. Before this patch, for the expression `-x` where `x` is a fp16 value, CIRGen emits the code like the following: ```mlir %0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float %1 = cir.cast float_to_float %0 : !cir.float -> !cir.f16 %2 = cir.unary minus %1 : !cir.fp16 ``` The expected CIRGen should instead be: ```mlir %0 = cir.cast float_to_float %x : !cir.f16 -> !cir.float %1 = cir.unary minus %0 : !cir.float %2 = cir.cast float_to_float %1 : !cir.float -> !cir.f16 ``` This PR fixes this issue.
1 parent 085a96e commit e7bcf3e

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+26-16
Original file line numberDiff line numberDiff line change
@@ -604,37 +604,47 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
604604
: PromotionType;
605605
auto result = VisitPlus(E, promotionTy);
606606
if (result && !promotionTy.isNull())
607-
result = buildUnPromotedValue(result, E->getType());
608-
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result);
607+
return buildUnPromotedValue(result, E->getType());
608+
return result;
609609
}
610610

611-
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
611+
mlir::Value VisitPlus(const UnaryOperator *E,
612+
QualType PromotionType = QualType()) {
612613
// This differs from gcc, though, most likely due to a bug in gcc.
613614
TestAndClearIgnoreResultAssign();
615+
616+
mlir::Value operand;
614617
if (!PromotionType.isNull())
615-
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
616-
return Visit(E->getSubExpr());
618+
operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
619+
else
620+
operand = Visit(E->getSubExpr());
621+
622+
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, operand);
617623
}
618624

619-
mlir::Value VisitUnaryMinus(const UnaryOperator *E) {
620-
// NOTE(cir): QualType function parameter still not used, so don´t replicate
621-
// it here yet.
622-
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
625+
mlir::Value VisitUnaryMinus(const UnaryOperator *E,
626+
QualType PromotionType = QualType()) {
627+
QualType promotionTy = PromotionType.isNull()
628+
? getPromotionType(E->getSubExpr()->getType())
629+
: PromotionType;
623630
auto result = VisitMinus(E, promotionTy);
624631
if (result && !promotionTy.isNull())
625-
result = buildUnPromotedValue(result, E->getType());
626-
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result);
632+
return buildUnPromotedValue(result, E->getType());
633+
return result;
627634
}
628635

629636
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
630637
TestAndClearIgnoreResultAssign();
638+
639+
mlir::Value operand;
631640
if (!PromotionType.isNull())
632-
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
641+
operand = CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
642+
else
643+
operand = Visit(E->getSubExpr());
633644

634645
// NOTE: LLVM codegen will lower this directly to either a FNeg
635646
// or a Sub instruction. In CIR this will be handled later in LowerToLLVM.
636-
637-
return Visit(E->getSubExpr());
647+
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, operand);
638648
}
639649

640650
mlir::Value VisitUnaryNot(const UnaryOperator *E) {
@@ -660,8 +670,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
660670
mlir::Value buildUnaryOp(const UnaryOperator *E, mlir::cir::UnaryOpKind kind,
661671
mlir::Value input) {
662672
return Builder.create<mlir::cir::UnaryOp>(
663-
CGF.getLoc(E->getSourceRange().getBegin()),
664-
CGF.getCIRType(E->getType()), kind, input);
673+
CGF.getLoc(E->getSourceRange().getBegin()), input.getType(), kind,
674+
input);
665675
}
666676

667677
// C++

clang/test/CIR/CodeGen/bf16-ops.c

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ void foo(void) {
3030

3131
h1 = -h1;
3232
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
33-
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16
34-
// NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.bf16, !cir.bf16
33+
// NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float
34+
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16
3535

3636
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
3737
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16
3838
// NATIVE: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.bf16, !cir.bf16
3939

4040
h1 = +h1;
4141
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
42-
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.bf16
43-
// NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.bf16, !cir.bf16
42+
// NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float
43+
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.bf16
4444

4545
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
4646
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.bf16

clang/test/CIR/CodeGen/fp16-ops.c

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ void foo(void) {
3030

3131
h1 = -h1;
3232
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
33-
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16
34-
// NONATIVE-NEXT: %{{.+}} = cir.unary(minus, %[[#B]]) : !cir.f16, !cir.f16
33+
// NONATIVE-NEXT: %[[#B:]] = cir.unary(minus, %[[#A]]) : !cir.float, !cir.float
34+
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16
3535

3636
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
3737
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16
3838
// NATIVE: %{{.+}} = cir.unary(minus, %{{.+}}) : !cir.f16, !cir.f16
3939

4040
h1 = +h1;
4141
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
42-
// NONATIVE-NEXT: %[[#B:]] = cir.cast(floating, %[[#A]] : !cir.float), !cir.f16
43-
// NONATIVE-NEXT: %{{.+}} = cir.unary(plus, %[[#B]]) : !cir.f16, !cir.f16
42+
// NONATIVE-NEXT: %[[#B:]] = cir.unary(plus, %[[#A]]) : !cir.float, !cir.float
43+
// NONATIVE-NEXT: %{{.+}} = cir.cast(floating, %[[#B]] : !cir.float), !cir.f16
4444

4545
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.f16), !cir.float
4646
// NATIVE-NOT: %{{.+}} = cir.cast(floating, %{{.+}} : !cir.float), !cir.f16

0 commit comments

Comments
 (0)