Skip to content

Commit c6fe12b

Browse files
authored
[SPIR-V Extension] fpbuiltin-max-error support (#2056)
This change adds SPIR-V translator support for the SPIR-V extension documented here: KhronosGroup/SPIRV-Registry#193. This extension adds one decoration to represent maximum error for FP operations and adds the related Capability. SPIRV Headers support for representing this in SPIR-V: KhronosGroup/SPIRV-Headers#363. intel/llvm#8134 added a new call-site attribute associated with FP builtin intrinsics. This attribute is named 'fpbuiltin-max-error'. The following example shows how this extension is supported in the translator. The input LLVM IR uses new LLVM builtin calls to represent FP operations. An attribute named 'fpbuiltin-max-error' is used to represent the max-error allowed in the FP operation. Example Input LLVM: %t6 = call float @llvm.fpbuiltin.sin.f32(float %f1) #2 attributes #2 = { "fpbuiltin-max-error"="2.5" } This is translated into a SPIR-V instruction (for add/sub/mul/div/rem) and an OpenCL extended instruction for other instructions. A decoration to represent the max-error is attached to the SPIR-V instruction. SPIR-V code: 4 Decorate 97 FPMaxErrorDecorationINTEL 1075838976 6 ExtInst 2 97 1 sin 88 No new support is added for translating this SPIR-V back to LLVM. Existing support is used. The decoration is translated back into named metadata associated with the LLVM instruction. This can be readily consumed by backends. Based on input from @andykaylor, we emit attributes when the FP operation is translated back to a call to a builtin function and emit metadata otherwise. Translated LLVM code for basic math functions (add/sub/mul/div/rem): %t6 = fmul float %f1, %f2, !fpbuiltin-max-error !7 !7 = !{!"2.500000"} Translated LLVM code for other math functions: %t6 = call spir_func float @_Z3sinf(float %f1) #3 attributes #3 = { "fpbuiltin-max-error"="4.000000" } Signed-off-by: Arvind Sudarsanam <[email protected]>
1 parent 613f536 commit c6fe12b

File tree

10 files changed

+464
-2
lines changed

10 files changed

+464
-2
lines changed

include/LLVMSPIRVExtensions.inc

+1
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,5 @@ EXT(SPV_INTEL_tensor_float32_rounding)
6363
EXT(SPV_EXT_relaxed_printf_string_address_space)
6464
EXT(SPV_INTEL_fpga_argument_interfaces)
6565
EXT(SPV_INTEL_fpga_latency_control)
66+
EXT(SPV_INTEL_fp_max_error)
6667
EXT(SPV_INTEL_cache_controls)

lib/SPIRV/SPIRVBuiltinHelper.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ Value *BuiltinCallMutator::doConversion() {
102102
CallInst *NewCall =
103103
Builder.Insert(addCallInst(CI->getModule(), FuncName, ReturnTy, Args,
104104
&Attrs, nullptr, Mangler.get()));
105+
NewCall->copyMetadata(*CI);
106+
NewCall->setAttributes(CI->getAttributes());
105107
Value *Result = MutateRet ? MutateRet(Builder, NewCall) : NewCall;
106108
Result->takeName(CI);
107109
if (!CI->getType()->isVoidTy())

lib/SPIRV/SPIRVReader.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -3893,7 +3893,48 @@ void SPIRVToLLVM::transDecorationsToMetadata(SPIRVValue *BV, Value *V) {
38933893
SetDecorationsMetadata(I);
38943894
}
38953895

3896+
namespace {
3897+
3898+
// Returns the float whose object representation is the bit pattern stored in
// the SPIR-V literal word `Spir`. No numeric conversion takes place: the word
// is written into one union member and read back through the other.
// NOTE(review): reading the inactive member of a union is formally undefined
// in ISO C++ (it is a widely supported compiler extension); consider a
// memcpy/bit_cast based implementation if the required header is available.
static float convertSPIRVWordToFloat(SPIRVWord Spir) {
  union {
    SPIRVWord Bits;
    float Value;
  } Pun;
  Pun.Bits = Spir;
  return Pun.Value;
}
3906+
3907+
// Translates a FPMaxErrorDecorationINTEL decoration found on the SPIR-V value
// `BV` onto the LLVM value `V`.
//
// The single decoration literal is reinterpreted as a float and rendered with
// std::to_string. When `V` is a call instruction the text is attached as the
// "fpbuiltin-max-error" function attribute of the call site; for any other
// instruction it is attached as "fpbuiltin-max-error" metadata.
//
// Returns true when the decoration was present and has been translated, and
// false when `V` is not an instruction, `BV` does not carry the decoration, or
// the decoration does not have exactly one literal.
static bool transFPMaxErrorDecoration(SPIRVValue *BV, Value *V,
                                      LLVMContext *Context) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  SPIRVWord ID;
  if (!BV->hasDecorate(DecorationFPMaxErrorDecorationINTEL, 0, &ID))
    return false;

  const auto Literals =
      BV->getDecorationLiterals(DecorationFPMaxErrorDecorationINTEL);
  assert(Literals.size() == 1 &&
         "FP Max Error decoration shall have 1 operand");
  // The assert above disappears in release builds; refuse to index the
  // literal list of a malformed decoration instead of reading out of bounds.
  if (Literals.size() != 1)
    return false;

  const std::string MaxError =
      std::to_string(convertSPIRVWordToFloat(Literals[0]));
  if (auto *CI = dyn_cast<CallInst>(I)) {
    // Calls carry the value as a call-site attribute.
    CI->addFnAttr(
        llvm::Attribute::get(*Context, "fpbuiltin-max-error", MaxError));
  } else {
    // Every other instruction carries the value as named metadata.
    MDNode *N = MDNode::get(*Context, MDString::get(*Context, MaxError));
    I->setMetadata("fpbuiltin-max-error", N);
  }
  return true;
}
3932+
} // namespace
3933+
38963934
bool SPIRVToLLVM::transDecoration(SPIRVValue *BV, Value *V) {
3935+
if (transFPMaxErrorDecoration(BV, V, Context))
3936+
return true;
3937+
38973938
if (!transAlign(BV, V))
38983939
return false;
38993940

lib/SPIRV/SPIRVWriter.cpp

+153-1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,19 @@ using namespace llvm;
106106
using namespace SPIRV;
107107
using namespace OCLUtil;
108108

109+
namespace {
110+
111+
// Returns the SPIR-V literal word holding the object representation of the
// float `F`. No numeric conversion takes place: the float is written into one
// union member and read back through the other.
// NOTE(review): reading the inactive member of a union is formally undefined
// in ISO C++ (it is a widely supported compiler extension); consider a
// memcpy/bit_cast based implementation if the required header is available.
static SPIRVWord convertFloatToSPIRVWord(float F) {
  union {
    SPIRVWord Bits;
    float Value;
  } Pun;
  Pun.Value = F;
  return Pun.Bits;
}
119+
120+
} // namespace
121+
109122
namespace SPIRV {
110123

111124
static void foreachKernelArgMD(
@@ -3481,6 +3494,26 @@ bool LLVMToSPIRVBase::isKnownIntrinsic(Intrinsic::ID Id) {
34813494
}
34823495
}
34833496

3497+
// Add decoration if needed
3498+
SPIRVInstruction *addFPBuiltinDecoration(SPIRVModule *BM, IntrinsicInst *II,
3499+
SPIRVInstruction *I) {
3500+
const bool AllowFPMaxError =
3501+
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_max_error);
3502+
assert(II->getCalledFunction()->getName().startswith("llvm.fpbuiltin"));
3503+
// Add a new decoration for llvm.builtin intrinsics, if needed
3504+
if (AllowFPMaxError)
3505+
if (II->getAttributes().hasFnAttr("fpbuiltin-max-error")) {
3506+
double F = 0.0;
3507+
II->getAttributes()
3508+
.getFnAttr("fpbuiltin-max-error")
3509+
.getValueAsString()
3510+
.getAsDouble(F);
3511+
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
3512+
convertFloatToSPIRVWord(F));
3513+
}
3514+
return I;
3515+
}
3516+
34843517
// Performs mapping of LLVM IR rounding mode to SPIR-V rounding mode
34853518
// Value *V is metadata <rounding mode> argument of
34863519
// llvm.experimental.constrained.* intrinsics
@@ -4424,8 +4457,9 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
44244457
}
44254458
return Result;
44264459
}
4427-
44284460
default:
4461+
if (auto *BVar = transFPBuiltinIntrinsicInst(II, BB))
4462+
return BVar;
44294463
if (BM->isUnknownIntrinsicAllowed(II))
44304464
return BM->addCallInst(
44314465
transFunctionDecl(II->getCalledFunction()),
@@ -4441,6 +4475,124 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
44414475
return nullptr;
44424476
}
44434477

4478+
// Classifies the intrinsic call `II` by the operation encoded in its callee
// name.
//
// Callees whose name does not start with "llvm.fpbuiltin" yield UNKNOWN and
// leave `OpName` untouched. Otherwise the "llvm.fpbuiltin." prefix is dropped,
// `OpName` receives the text up to the next '.', and that text selects the
// category: the five basic arithmetic operations are REGULAR_MATH, and the
// remaining recognised names are grouped by operand count into EXT_1OPS,
// EXT_2OPS and EXT_3OPS. Unrecognised operation names yield UNKNOWN.
LLVMToSPIRVBase::FPBuiltinType
LLVMToSPIRVBase::getFPBuiltinType(IntrinsicInst *II, StringRef &OpName) {
  StringRef CalleeName = II->getCalledFunction()->getName();
  if (!CalleeName.startswith("llvm.fpbuiltin"))
    return FPBuiltinType::UNKNOWN;

  CalleeName.consume_front("llvm.fpbuiltin.");
  OpName = CalleeName.split('.').first;

  return StringSwitch<FPBuiltinType>(OpName)
      // Basic arithmetic.
      .Cases("fadd", "fsub", "fmul", "fdiv", "frem",
             FPBuiltinType::REGULAR_MATH)
      // Trigonometric, hyperbolic and their inverses.
      .Cases("sin", "cos", "tan", FPBuiltinType::EXT_1OPS)
      .Cases("sinh", "cosh", "tanh", FPBuiltinType::EXT_1OPS)
      .Cases("asin", "acos", "atan", FPBuiltinType::EXT_1OPS)
      .Cases("asinh", "acosh", "atanh", FPBuiltinType::EXT_1OPS)
      // Exponentials, logarithms, roots and error functions.
      .Cases("exp", "exp2", "exp10", "expm1", FPBuiltinType::EXT_1OPS)
      .Cases("log", "log2", "log10", "log1p", FPBuiltinType::EXT_1OPS)
      .Cases("sqrt", "rsqrt", "erf", "erfc", FPBuiltinType::EXT_1OPS)
      // Two- and three-operand functions.
      .Cases("atan2", "pow", "hypot", "ldexp", FPBuiltinType::EXT_2OPS)
      .Case("sincos", FPBuiltinType::EXT_3OPS)
      .Default(FPBuiltinType::UNKNOWN);
}
4501+
4502+
// Translates a call to an llvm.fpbuiltin.* intrinsic into SPIR-V.
//
// fadd/fsub/fmul/fdiv/frem become the matching SPIR-V binary instruction; the
// remaining recognised operations become an instruction of the OpenCL extended
// instruction set taking the first one, two or three call arguments. The
// resulting instruction is passed through addFPBuiltinDecoration before it is
// returned.
//
// Returns nullptr when the callee is not a recognised llvm.fpbuiltin.*
// operation, or when checkTypeForSPIRVExtendedInstLowering rejects an
// extended-instruction candidate.
SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
                                                         SPIRVBasicBlock *BB) {
  StringRef OpName;
  const FPBuiltinType Kind = getFPBuiltinType(II, OpName);

  // Shared lowering for every extended-instruction category: translate the
  // result type, then the first NumOperands arguments in order, and emit the
  // decorated OpenCL extended instruction.
  auto EmitOpenCLExtInst = [&](SPIRVWord ExtOp,
                               unsigned NumOperands) -> SPIRVValue * {
    if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
      return nullptr;
    assert(ExtOp != SPIRVWORD_MAX &&
           "fpbuiltin operation has no OpenCL extended instruction mapping");
    SPIRVType *STy = transType(II->getType());
    std::vector<SPIRVValue *> Ops;
    Ops.reserve(NumOperands);
    for (unsigned Idx = 0; Idx != NumOperands; ++Idx)
      Ops.push_back(transValue(II->getArgOperand(Idx), BB));
    auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
                              Ops, BB);
    return addFPBuiltinDecoration(BM, II, BI);
  };

  switch (Kind) {
  case FPBuiltinType::REGULAR_MATH: {
    const Op BinOp = StringSwitch<Op>(OpName)
                         .Case("fadd", OpFAdd)
                         .Case("fsub", OpFSub)
                         .Case("fmul", OpFMul)
                         .Case("fdiv", OpFDiv)
                         .Case("frem", OpFRem)
                         .Default(OpUndef);
    assert(BinOp != OpUndef &&
           "fpbuiltin operation has no SPIR-V binary instruction mapping");
    // Translate the type and operands in separate statements: as function
    // arguments their evaluation order would be unspecified, and the
    // translation calls may themselves emit SPIR-V, so the output could
    // otherwise differ between compilers.
    SPIRVType *ResTy = transType(II->getType());
    SPIRVValue *LHS = transValue(II->getArgOperand(0), BB);
    SPIRVValue *RHS = transValue(II->getArgOperand(1), BB);
    auto *BI = BM->addBinaryInst(BinOp, ResTy, LHS, RHS, BB);
    return addFPBuiltinDecoration(BM, II, BI);
  }
  case FPBuiltinType::EXT_1OPS: {
    const SPIRVWord ExtOp = StringSwitch<SPIRVWord>(OpName)
                                .Case("sin", OpenCLLIB::Sin)
                                .Case("cos", OpenCLLIB::Cos)
                                .Case("tan", OpenCLLIB::Tan)
                                .Case("sinh", OpenCLLIB::Sinh)
                                .Case("cosh", OpenCLLIB::Cosh)
                                .Case("tanh", OpenCLLIB::Tanh)
                                .Case("asin", OpenCLLIB::Asin)
                                .Case("acos", OpenCLLIB::Acos)
                                .Case("atan", OpenCLLIB::Atan)
                                .Case("asinh", OpenCLLIB::Asinh)
                                .Case("acosh", OpenCLLIB::Acosh)
                                .Case("atanh", OpenCLLIB::Atanh)
                                .Case("exp", OpenCLLIB::Exp)
                                .Case("exp2", OpenCLLIB::Exp2)
                                .Case("exp10", OpenCLLIB::Exp10)
                                .Case("expm1", OpenCLLIB::Expm1)
                                .Case("log", OpenCLLIB::Log)
                                .Case("log2", OpenCLLIB::Log2)
                                .Case("log10", OpenCLLIB::Log10)
                                .Case("log1p", OpenCLLIB::Log1p)
                                .Case("sqrt", OpenCLLIB::Sqrt)
                                .Case("rsqrt", OpenCLLIB::Rsqrt)
                                .Case("erf", OpenCLLIB::Erf)
                                .Case("erfc", OpenCLLIB::Erfc)
                                .Default(SPIRVWORD_MAX);
    return EmitOpenCLExtInst(ExtOp, 1);
  }
  case FPBuiltinType::EXT_2OPS: {
    const SPIRVWord ExtOp = StringSwitch<SPIRVWord>(OpName)
                                .Case("atan2", OpenCLLIB::Atan2)
                                .Case("hypot", OpenCLLIB::Hypot)
                                .Case("pow", OpenCLLIB::Pow)
                                .Case("ldexp", OpenCLLIB::Ldexp)
                                .Default(SPIRVWORD_MAX);
    return EmitOpenCLExtInst(ExtOp, 2);
  }
  case FPBuiltinType::EXT_3OPS: {
    const SPIRVWord ExtOp = StringSwitch<SPIRVWord>(OpName)
                                .Case("sincos", OpenCLLIB::Sincos)
                                .Default(SPIRVWORD_MAX);
    return EmitOpenCLExtInst(ExtOp, 3);
  }
  default:
    // FPBuiltinType::UNKNOWN: not an fpbuiltin intrinsic this translator
    // handles.
    return nullptr;
  }
}
4595+
44444596
SPIRVValue *LLVMToSPIRVBase::transFenceInst(FenceInst *FI,
44454597
SPIRVBasicBlock *BB) {
44464598
SPIRVWord MemorySemantics;

lib/SPIRV/SPIRVWriter.h

+10
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ class LLVMToSPIRVBase : protected BuiltinCallHelper {
108108
bool transBuiltinSet();
109109
bool isKnownIntrinsic(Intrinsic::ID Id);
110110
SPIRVValue *transIntrinsicInst(IntrinsicInst *Intrinsic, SPIRVBasicBlock *BB);
111+
// Category assigned to an llvm.fpbuiltin.* intrinsic call by
// getFPBuiltinType; it selects how transFPBuiltinIntrinsicInst lowers the
// call.
enum class FPBuiltinType {
  REGULAR_MATH, // fadd, fsub, fmul, fdiv, frem
  EXT_1OPS,     // extended instruction taking one operand (sin, exp, ...)
  EXT_2OPS,     // extended instruction taking two operands (atan2, pow, ...)
  EXT_3OPS,     // extended instruction taking three operands (sincos)
  UNKNOWN       // not a recognised fpbuiltin operation
};
118+
FPBuiltinType getFPBuiltinType(IntrinsicInst *II, StringRef &);
119+
SPIRVValue *transFPBuiltinIntrinsicInst(IntrinsicInst *II,
120+
SPIRVBasicBlock *BB);
111121
SPIRVValue *transFenceInst(FenceInst *FI, SPIRVBasicBlock *BB);
112122
SPIRVValue *transCallInst(CallInst *Call, SPIRVBasicBlock *BB);
113123
SPIRVValue *transDirectCallInst(CallInst *Call, SPIRVBasicBlock *BB);

lib/SPIRV/libSPIRV/SPIRVDecorate.h

+2
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ class SPIRVDecorate : public SPIRVDecorateGeneric {
201201
case DecorationLatencyControlLabelINTEL:
202202
case DecorationLatencyControlConstraintINTEL:
203203
return ExtensionID::SPV_INTEL_fpga_latency_control;
204+
case DecorationFPMaxErrorDecorationINTEL:
205+
return ExtensionID::SPV_INTEL_fp_max_error;
204206
case internal::DecorationCacheControlLoadINTEL:
205207
case internal::DecorationCacheControlStoreINTEL:
206208
return ExtensionID::SPV_INTEL_cache_controls;

lib/SPIRV/libSPIRV/SPIRVEnum.h

+2
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,8 @@ template <> inline void SPIRVMap<Decoration, SPIRVCapVec>::init() {
497497
{CapabilityFPGALatencyControlINTEL});
498498
ADD_VEC_INIT(DecorationLatencyControlConstraintINTEL,
499499
{CapabilityFPGALatencyControlINTEL});
500+
ADD_VEC_INIT(DecorationFPMaxErrorDecorationINTEL,
501+
{CapabilityFPMaxErrorINTEL});
500502
}
501503

502504
template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

+2
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ template <> inline void SPIRVMap<Decoration, std::string>::init() {
198198
add(DecorationStableKernelArgumentINTEL, "StableKernelArgumentINTEL");
199199
add(DecorationLatencyControlLabelINTEL, "LatencyControlLabelINTEL");
200200
add(DecorationLatencyControlConstraintINTEL, "LatencyControlConstraintINTEL");
201+
add(DecorationFPMaxErrorDecorationINTEL, "FPMaxErrorDecorationINTEL");
201202

202203
// From spirv_internal.hpp
203204
add(internal::DecorationCallableFunctionINTEL, "CallableFunctionINTEL");
@@ -623,6 +624,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
623624
add(CapabilityMax, "Max");
624625
add(CapabilityFPGAArgumentInterfacesINTEL, "FPGAArgumentInterfacesINTEL");
625626
add(CapabilityFPGALatencyControlINTEL, "FPGALatencyControlINTEL");
627+
add(CapabilityFPMaxErrorINTEL, "FPMaxErrorINTEL");
626628
// From spirv_internal.hpp
627629
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");
628630
add(internal::CapabilityOptNoneINTEL, "OptNoneINTEL");

spirv-headers-tag.conf

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
9b527c0fb60124936d0906d44803bec51a0200fb
1+
51b106461707f46d962554efe1bf56dee28958a3

0 commit comments

Comments
 (0)