Skip to content

Commit f8816dc

Browse files
[Backport to 14] fpbuiltin-max-error support
Changes were cherry-picked from the following commit: c6fe12b This changes add SPIR-V translator support for the SPIR-V extension documented here: KhronosGroup/SPIRV-Registry#193. This extension adds one decoration to represent maximum error for FP operations and adds the related Capability. SPIRV Headers support for representing this in SPIR-V: KhronosGroup/SPIRV-Headers#363 intel/llvm#8134 added a new call-site attribute associated with FP builtin intrinsics. This attribute is named 'fpbuiltin-max-error'. Following example shows how this extension is supported in the translator. The input LLVM IR uses new LLVM builtin calls to represent FP operations. An attribute named 'fpbuiltin-max-error' is used to represent the max-error allowed in the FP operation. Example Input LLVM: %t6 = call float @llvm.fpbuiltin.sin.f32(float %f1) KhronosGroup#2 attributes KhronosGroup#2 = { "fpbuiltin-max-error"="2.5" } This is translated into a SPIR-V instruction (for add/sub/mul/div/rem) and OpenCl extended instruction for other instructions. A decoration to represent the max-error is attached to the SPIR-V instruction. SPIR-V code: 4 Decorate 97 FPMaxErrorDecorationINTEL 1075838976 6 ExtInst 2 97 1 sin 88 No new support is added to support translating this SPIR_V back to LLVM. Existing support is used. The decoration is translated back into named metadata associated with the LLVM instruction. This can be readily consumed by backends. Based on input from @andykaylor, we emit attributes when the FP operation is translated back to a call to a builtin function and emit metadata otherwise. Translated LLVM code for basic math functions (add/sub/mul/div/rem): %t6 = fmul float %f1, %f2, !fpbuiltin-max-error !7 !7 = !{!"2.500000"} Translated LLVM code for other math functions: %t6 = call spir_func float @_Z3sinf(float %f1) KhronosGroup#3 attributes KhronosGroup#3 = { "fpbuiltin-max-error"="4.000000" }
1 parent 0141f3d commit f8816dc

File tree

9 files changed

+463
-0
lines changed

9 files changed

+463
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,5 @@ EXT(SPV_INTEL_masked_gather_scatter)
6363
EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
6464
EXT(SPV_INTEL_tensor_float32_rounding)
6565
EXT(SPV_EXT_relaxed_printf_string_address_space)
66+
EXT(SPV_INTEL_fp_max_error)
6667
EXT(SPV_INTEL_cache_controls)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3819,7 +3819,48 @@ void SPIRVToLLVM::transDecorationsToMetadata(SPIRVValue *BV, Value *V) {
38193819
SetDecorationsMetadata(I);
38203820
}
38213821

3822+
namespace {
3823+
3824+
static float convertSPIRVWordToFloat(SPIRVWord Spir) {
3825+
union {
3826+
float F;
3827+
SPIRVWord Spir;
3828+
} FPMaxError;
3829+
FPMaxError.Spir = Spir;
3830+
return FPMaxError.F;
3831+
}
3832+
3833+
static bool transFPMaxErrorDecoration(SPIRVValue *BV, Value *V,
3834+
LLVMContext *Context) {
3835+
SPIRVWord ID;
3836+
if (Instruction *I = dyn_cast<Instruction>(V))
3837+
if (BV->hasDecorate(DecorationFPMaxErrorDecorationINTEL, 0, &ID)) {
3838+
auto Literals =
3839+
BV->getDecorationLiterals(DecorationFPMaxErrorDecorationINTEL);
3840+
assert(Literals.size() == 1 &&
3841+
"FP Max Error decoration shall have 1 operand");
3842+
auto F = convertSPIRVWordToFloat(Literals[0]);
3843+
if (CallInst *CI = dyn_cast<CallInst>(I)) {
3844+
// Add attribute
3845+
auto A = llvm::Attribute::get(*Context, "fpbuiltin-max-error",
3846+
std::to_string(F));
3847+
CI->addFnAttr(A);
3848+
} else {
3849+
// Add metadata
3850+
MDNode *N =
3851+
MDNode::get(*Context, MDString::get(*Context, std::to_string(F)));
3852+
I->setMetadata("fpbuiltin-max-error", N);
3853+
}
3854+
return true;
3855+
}
3856+
return false;
3857+
}
3858+
} // namespace
3859+
38223860
bool SPIRVToLLVM::transDecoration(SPIRVValue *BV, Value *V) {
3861+
if (transFPMaxErrorDecoration(BV, V, Context))
3862+
return true;
3863+
38233864
if (!transAlign(BV, V))
38243865
return false;
38253866

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,8 @@ CallInst *mutateCallInst(
766766
auto NewCI = addCallInst(M, NewName, CI->getType(), Args, Attrs, CI, Mangle,
767767
InstName, TakeFuncName);
768768
NewCI->setDebugLoc(CI->getDebugLoc());
769+
NewCI->copyMetadata(*CI);
770+
NewCI->setAttributes(CI->getAttributes());
769771
LLVM_DEBUG(dbgs() << " => " << *NewCI << '\n');
770772
CI->replaceAllUsesWith(NewCI);
771773
CI->eraseFromParent();

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ using namespace llvm;
105105
using namespace SPIRV;
106106
using namespace OCLUtil;
107107

108+
namespace {
109+
110+
static SPIRVWord convertFloatToSPIRVWord(float F) {
111+
union {
112+
float F;
113+
SPIRVWord Spir;
114+
} FPMaxError;
115+
FPMaxError.F = F;
116+
return FPMaxError.Spir;
117+
}
118+
119+
} // namespace
120+
108121
namespace SPIRV {
109122

110123
static void foreachKernelArgMD(
@@ -3396,6 +3409,26 @@ bool LLVMToSPIRVBase::isKnownIntrinsic(Intrinsic::ID Id) {
33963409
}
33973410
}
33983411

3412+
// Add decoration if needed
3413+
SPIRVInstruction *addFPBuiltinDecoration(SPIRVModule *BM, IntrinsicInst *II,
3414+
SPIRVInstruction *I) {
3415+
const bool AllowFPMaxError =
3416+
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_max_error);
3417+
assert(II->getCalledFunction()->getName().startswith("llvm.fpbuiltin"));
3418+
// Add a new decoration for llvm.builtin intrinsics, if needed
3419+
if (AllowFPMaxError)
3420+
if (II->getAttributes().hasFnAttr("fpbuiltin-max-error")) {
3421+
double F = 0.0;
3422+
II->getAttributes()
3423+
.getFnAttr("fpbuiltin-max-error")
3424+
.getValueAsString()
3425+
.getAsDouble(F);
3426+
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
3427+
convertFloatToSPIRVWord(F));
3428+
}
3429+
return I;
3430+
}
3431+
33993432
// Performs mapping of LLVM IR rounding mode to SPIR-V rounding mode
34003433
// Value *V is metadata <rounding mode> argument of
34013434
// llvm.experimental.constrained.* intrinsics
@@ -4090,6 +4123,8 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
40904123
}
40914124

40924125
default:
4126+
if (auto *BVar = transFPBuiltinIntrinsicInst(II, BB))
4127+
return BVar;
40934128
if (BM->isUnknownIntrinsicAllowed(II))
40944129
return BM->addCallInst(
40954130
transFunctionDecl(II->getCalledFunction()),
@@ -4105,6 +4140,124 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
41054140
return nullptr;
41064141
}
41074142

4143+
LLVMToSPIRVBase::FPBuiltinType
4144+
LLVMToSPIRVBase::getFPBuiltinType(IntrinsicInst *II, StringRef &OpName) {
4145+
StringRef Name = II->getCalledFunction()->getName();
4146+
if (!Name.startswith("llvm.fpbuiltin"))
4147+
return FPBuiltinType::UNKNOWN;
4148+
Name.consume_front("llvm.fpbuiltin.");
4149+
OpName = Name.split('.').first;
4150+
FPBuiltinType Type =
4151+
StringSwitch<FPBuiltinType>(OpName)
4152+
.Cases("fadd", "fsub", "fmul", "fdiv", "frem",
4153+
FPBuiltinType::REGULAR_MATH)
4154+
.Cases("sin", "cos", "tan", FPBuiltinType::EXT_1OPS)
4155+
.Cases("sinh", "cosh", "tanh", FPBuiltinType::EXT_1OPS)
4156+
.Cases("asin", "acos", "atan", FPBuiltinType::EXT_1OPS)
4157+
.Cases("asinh", "acosh", "atanh", FPBuiltinType::EXT_1OPS)
4158+
.Cases("exp", "exp2", "exp10", "expm1", FPBuiltinType::EXT_1OPS)
4159+
.Cases("log", "log2", "log10", "log1p", FPBuiltinType::EXT_1OPS)
4160+
.Cases("sqrt", "rsqrt", "erf", "erfc", FPBuiltinType::EXT_1OPS)
4161+
.Cases("atan2", "pow", "hypot", "ldexp", FPBuiltinType::EXT_2OPS)
4162+
.Case("sincos", FPBuiltinType::EXT_3OPS)
4163+
.Default(FPBuiltinType::UNKNOWN);
4164+
return Type;
4165+
}
4166+
4167+
SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
4168+
SPIRVBasicBlock *BB) {
4169+
StringRef OpName;
4170+
auto FPBuiltinTypeVal = getFPBuiltinType(II, OpName);
4171+
if (FPBuiltinTypeVal == FPBuiltinType::UNKNOWN)
4172+
return nullptr;
4173+
switch (FPBuiltinTypeVal) {
4174+
case FPBuiltinType::REGULAR_MATH: {
4175+
auto BinOp = StringSwitch<Op>(OpName)
4176+
.Case("fadd", OpFAdd)
4177+
.Case("fsub", OpFSub)
4178+
.Case("fmul", OpFMul)
4179+
.Case("fdiv", OpFDiv)
4180+
.Case("frem", OpFRem)
4181+
.Default(OpUndef);
4182+
auto *BI = BM->addBinaryInst(BinOp, transType(II->getType()),
4183+
transValue(II->getArgOperand(0), BB),
4184+
transValue(II->getArgOperand(1), BB), BB);
4185+
return addFPBuiltinDecoration(BM, II, BI);
4186+
}
4187+
case FPBuiltinType::EXT_1OPS: {
4188+
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
4189+
break;
4190+
SPIRVType *STy = transType(II->getType());
4191+
std::vector<SPIRVValue *> Ops(1, transValue(II->getArgOperand(0), BB));
4192+
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
4193+
.Case("sin", OpenCLLIB::Sin)
4194+
.Case("cos", OpenCLLIB::Cos)
4195+
.Case("tan", OpenCLLIB::Tan)
4196+
.Case("sinh", OpenCLLIB::Sinh)
4197+
.Case("cosh", OpenCLLIB::Cosh)
4198+
.Case("tanh", OpenCLLIB::Tanh)
4199+
.Case("asin", OpenCLLIB::Asin)
4200+
.Case("acos", OpenCLLIB::Acos)
4201+
.Case("atan", OpenCLLIB::Atan)
4202+
.Case("asinh", OpenCLLIB::Asinh)
4203+
.Case("acosh", OpenCLLIB::Acosh)
4204+
.Case("atanh", OpenCLLIB::Atanh)
4205+
.Case("exp", OpenCLLIB::Exp)
4206+
.Case("exp2", OpenCLLIB::Exp2)
4207+
.Case("exp10", OpenCLLIB::Exp10)
4208+
.Case("expm1", OpenCLLIB::Expm1)
4209+
.Case("log", OpenCLLIB::Log)
4210+
.Case("log2", OpenCLLIB::Log2)
4211+
.Case("log10", OpenCLLIB::Log10)
4212+
.Case("log1p", OpenCLLIB::Log1p)
4213+
.Case("sqrt", OpenCLLIB::Sqrt)
4214+
.Case("rsqrt", OpenCLLIB::Rsqrt)
4215+
.Case("erf", OpenCLLIB::Erf)
4216+
.Case("erfc", OpenCLLIB::Erfc)
4217+
.Default(SPIRVWORD_MAX);
4218+
assert(ExtOp != SPIRVWORD_MAX);
4219+
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
4220+
Ops, BB);
4221+
return addFPBuiltinDecoration(BM, II, BI);
4222+
}
4223+
case FPBuiltinType::EXT_2OPS: {
4224+
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
4225+
break;
4226+
SPIRVType *STy = transType(II->getType());
4227+
std::vector<SPIRVValue *> Ops{transValue(II->getArgOperand(0), BB),
4228+
transValue(II->getArgOperand(1), BB)};
4229+
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
4230+
.Case("atan2", OpenCLLIB::Atan2)
4231+
.Case("hypot", OpenCLLIB::Hypot)
4232+
.Case("pow", OpenCLLIB::Pow)
4233+
.Case("ldexp", OpenCLLIB::Ldexp)
4234+
.Default(SPIRVWORD_MAX);
4235+
assert(ExtOp != SPIRVWORD_MAX);
4236+
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
4237+
Ops, BB);
4238+
return addFPBuiltinDecoration(BM, II, BI);
4239+
}
4240+
case FPBuiltinType::EXT_3OPS: {
4241+
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
4242+
break;
4243+
SPIRVType *STy = transType(II->getType());
4244+
std::vector<SPIRVValue *> Ops{transValue(II->getArgOperand(0), BB),
4245+
transValue(II->getArgOperand(1), BB),
4246+
transValue(II->getArgOperand(2), BB)};
4247+
auto ExtOp = StringSwitch<SPIRVWord>(OpName)
4248+
.Case("sincos", OpenCLLIB::Sincos)
4249+
.Default(SPIRVWORD_MAX);
4250+
assert(ExtOp != SPIRVWORD_MAX);
4251+
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
4252+
Ops, BB);
4253+
return addFPBuiltinDecoration(BM, II, BI);
4254+
}
4255+
default:
4256+
return nullptr;
4257+
}
4258+
return nullptr;
4259+
}
4260+
41084261
SPIRVValue *LLVMToSPIRVBase::transFenceInst(FenceInst *FI,
41094262
SPIRVBasicBlock *BB) {
41104263
SPIRVWord MemorySemantics;

lib/SPIRV/SPIRVWriter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ class LLVMToSPIRVBase {
108108
bool transWorkItemBuiltinCallsToVariables();
109109
bool isKnownIntrinsic(Intrinsic::ID Id);
110110
SPIRVValue *transIntrinsicInst(IntrinsicInst *Intrinsic, SPIRVBasicBlock *BB);
111+
enum class FPBuiltinType {
112+
REGULAR_MATH,
113+
EXT_1OPS,
114+
EXT_2OPS,
115+
EXT_3OPS,
116+
UNKNOWN
117+
};
118+
FPBuiltinType getFPBuiltinType(IntrinsicInst *II, StringRef &);
119+
SPIRVValue *transFPBuiltinIntrinsicInst(IntrinsicInst *II,
120+
SPIRVBasicBlock *BB);
111121
SPIRVValue *transFenceInst(FenceInst *FI, SPIRVBasicBlock *BB);
112122
SPIRVValue *transCallInst(CallInst *Call, SPIRVBasicBlock *BB);
113123
SPIRVValue *transDirectCallInst(CallInst *Call, SPIRVBasicBlock *BB);

lib/SPIRV/libSPIRV/SPIRVDecorate.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ class SPIRVDecorate : public SPIRVDecorateGeneric {
178178
case internal::DecorationInitModeINTEL:
179179
case internal::DecorationImplementInCSRINTEL:
180180
return ExtensionID::SPV_INTEL_global_variable_decorations;
181+
case DecorationFPMaxErrorDecorationINTEL:
182+
return ExtensionID::SPV_INTEL_fp_max_error;
181183
case internal::DecorationCacheControlLoadINTEL:
182184
case internal::DecorationCacheControlStoreINTEL:
183185
return ExtensionID::SPV_INTEL_cache_controls;

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ template <> inline void SPIRVMap<Decoration, SPIRVCapVec>::init() {
470470
{internal::CapabilityCacheControlsINTEL});
471471
ADD_VEC_INIT(internal::DecorationCacheControlStoreINTEL,
472472
{internal::CapabilityCacheControlsINTEL});
473+
ADD_VEC_INIT(DecorationFPMaxErrorDecorationINTEL,
474+
{CapabilityFPMaxErrorINTEL});
473475
}
474476

475477
template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ template <> inline void SPIRVMap<Decoration, std::string>::init() {
179179
add(DecorationMediaBlockIOINTEL, "MediaBlockIOINTEL");
180180
add(DecorationAliasScopeINTEL, "AliasScopeINTEL");
181181
add(DecorationNoAliasINTEL, "NoAliasINTEL");
182+
add(DecorationFPMaxErrorDecorationINTEL, "FPMaxErrorDecorationINTEL");
182183

183184
// From spirv_internal.hpp
184185
add(internal::DecorationCallableFunctionINTEL, "CallableFunctionINTEL");
@@ -601,6 +602,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
601602
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
602603
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");
603604
add(CapabilityGroupUniformArithmeticKHR, "GroupUniformArithmeticKHR");
605+
add(CapabilityFPMaxErrorINTEL, "FPMaxErrorINTEL");
604606

605607
// From spirv_internal.hpp
606608
add(internal::CapabilityFPGADSPControlINTEL, "FPGADSPControlINTEL");

0 commit comments

Comments
 (0)