Skip to content

Commit 0414482

Browse files
authored
Rewrite OpenCL explicit conversion builtins handling (#2464)
Check that builtin is valid mostly by matching the regular expression - remove unneeded checks. This is a follow-up to #2443
1 parent 5242c38 commit 0414482

File tree

2 files changed

+41
-52
lines changed

2 files changed

+41
-52
lines changed

lib/SPIRV/OCLToSPIRV.cpp

+19-49
Original file line numberDiff line numberDiff line change
@@ -744,56 +744,26 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
744744
if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
745745
SrcTy = VecTy->getElementType();
746746
auto IsTargetInt = isa<IntegerType>(TargetTy);
747-
auto TargetSigned = DemangledName[8] != 'u';
748747

749-
std::string TargetTyName(
750-
DemangledName.substr(strlen(kOCLBuiltinName::ConvertPrefix)));
751-
auto FirstUnderscoreLoc = TargetTyName.find('_');
752-
if (FirstUnderscoreLoc != std::string::npos)
753-
TargetTyName = TargetTyName.substr(0, FirstUnderscoreLoc);
754-
755-
// Validate target type name
756-
std::regex Expr("([a-z]+)([0-9]*)$");
748+
// Validate conversion function name and vector size if present
749+
std::regex Expr(
750+
"convert_(float|double|half|u?char|u?short|u?int|u?long)(2|3|4|8|16)*"
751+
"(_sat)*(_rt[ezpn])*$");
757752
std::smatch DestTyMatch;
758-
if (!std::regex_match(TargetTyName, DestTyMatch, Expr))
753+
std::string ConversionFunc(DemangledName.str());
754+
if (!std::regex_match(ConversionFunc, DestTyMatch, Expr))
759755
return;
760756

761757
// The first sub_match is the whole string; the next
762-
// sub_match is the first parenthesized expression.
763-
std::string DestTy = DestTyMatch[1].str();
764-
765-
// check it's valid type name
766-
static std::unordered_set<std::string> ValidTypes = {
767-
"float", "double", "half", "char", "uchar", "short",
768-
"ushort", "int", "uint", "long", "ulong"};
769-
770-
if (ValidTypes.find(DestTy) == ValidTypes.end())
771-
return;
772-
773-
// check that it's allowed vector size
774-
std::string VecSize = DestTyMatch[2].str();
775-
if (!VecSize.empty()) {
776-
int Size = stoi(VecSize);
777-
switch (Size) {
778-
case 2:
779-
case 3:
780-
case 4:
781-
case 8:
782-
case 16:
783-
break;
784-
default:
785-
return;
786-
}
787-
}
788-
DemangledName = DemangledName.drop_front(
789-
strlen(kOCLBuiltinName::ConvertPrefix) + TargetTyName.size());
790-
TargetTyName = std::string("_R") + TargetTyName;
758+
// sub_matches are the parenthesized expressions.
759+
enum { TypeIdx = 1, VecSizeIdx = 2, SatIdx = 3, RoundingIdx = 4 };
760+
std::string DestTy = DestTyMatch[TypeIdx].str();
761+
std::string VecSize = DestTyMatch[VecSizeIdx].str();
762+
std::string Sat = DestTyMatch[SatIdx].str();
763+
std::string Rounding = DestTyMatch[RoundingIdx].str();
791764

792-
if (!DemangledName.empty() && !DemangledName.starts_with("_sat") &&
793-
!DemangledName.starts_with("_rt"))
794-
return;
765+
bool TargetSigned = DestTy[0] != 'u';
795766

796-
std::string Sat = DemangledName.find("_sat") != StringRef::npos ? "_sat" : "";
797767
if (isa<IntegerType>(SrcTy)) {
798768
bool Signed = isLastFuncParamSigned(MangledName);
799769
if (IsTargetInt) {
@@ -810,13 +780,13 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
810780
} else
811781
OC = OpFConvert;
812782
}
813-
auto Loc = DemangledName.find("_rt");
814-
std::string Rounding;
815-
if (Loc != StringRef::npos && !(isa<IntegerType>(SrcTy) && IsTargetInt)) {
816-
Rounding = DemangledName.substr(Loc, 4).str();
817-
}
783+
784+
if (!Rounding.empty() && (isa<IntegerType>(SrcTy) && IsTargetInt))
785+
return;
786+
818787
assert(CI->getCalledFunction() && "Unexpected indirect call");
819-
mutateCallInst(CI, getSPIRVFuncName(OC, TargetTyName + Sat + Rounding));
788+
mutateCallInst(
789+
CI, getSPIRVFuncName(OC, "_R" + DestTy + VecSize + Sat + Rounding));
820790
}
821791

822792
void OCLToSPIRVBase::visitCallGroupBuiltin(CallInst *CI,

test/transcoding/OpenCL/convert_functions.ll

+22-3
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@
1111
; RUN: FileCheck < %t.rev.ll %s -check-prefix=CHECK-LLVM
1212

1313
; CHECK-SPIRV: Name [[#Func:]] "_Z18convert_float_func"
14-
; CHECK-SPIRV: TypeVoid [[#VoidTy:]]
15-
; CHECK-SPIRV: TypeFloat [[#FloatTy:]] 32
14+
; CHECK-SPIRV: Name [[#Func1:]] "_Z20convert_uint_satfunc"
15+
; CHECK-SPIRV: Name [[#Func2:]] "_Z21convert_float_rtzfunc"
16+
; CHECK-SPIRV-DAG: TypeVoid [[#VoidTy:]]
17+
; CHECK-SPIRV-DAG: TypeFloat [[#FloatTy:]] 32
1618

1719
; CHECK-SPIRV: Function [[#VoidTy]] [[#Func]]
1820
; CHECK-SPIRV: ConvertSToF [[#FloatTy]] [[#ConvertId:]] [[#]]
1921
; CHECK-SPIRV: FunctionCall [[#VoidTy]] [[#]] [[#Func]] [[#ConvertId]]
22+
; CHECK-SPIRV: FunctionCall [[#VoidTy]] [[#]] [[#Func1]] [[#]]
23+
; CHECK-SPIRV: FunctionCall [[#VoidTy]] [[#]] [[#Func2]] [[#ConvertId]]
2024
; CHECK-SPIRV-NOT: FConvert
25+
; CHECK-SPIRV-NOT: ConvertUToF
2126

2227
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
2328
target triple = "spir"
@@ -30,16 +35,30 @@ entry:
3035
ret void
3136
}
3237

38+
define dso_local spir_func void @_Z20convert_uint_satfunc(i32 noundef %x) #0 {
39+
entry:
40+
ret void
41+
}
42+
43+
define dso_local spir_func void @_Z21convert_float_rtzfunc(float noundef %x) #0 {
44+
entry:
45+
ret void
46+
}
47+
3348
; Function Attrs: convergent noinline norecurse nounwind optnone
3449
define dso_local spir_func void @convert_int_bf16(i32 noundef %x) #0 {
3550
entry:
3651
%x.addr = alloca i32, align 4
3752
store i32 %x, ptr %x.addr, align 4
3853
%0 = load i32, ptr %x.addr, align 4
3954
; CHECK-LLVM: %[[Call:[a-z]+]] = sitofp i32 %[[#]] to float
40-
%call = call spir_func float @_Z13convert_floati(i32 noundef %0) #1
4155
; CHECK-LLVM: call spir_func void @_Z18convert_float_func(float %[[Call]])
56+
; CHECK-LLVM: call spir_func void @_Z20convert_uint_satfunc(i32 %[[#]])
57+
; CHECK-LLVM: call spir_func void @_Z21convert_float_rtzfunc(float %[[Call]])
58+
%call = call spir_func float @_Z13convert_floati(i32 noundef %0) #1
4259
call spir_func void @_Z18convert_float_func(float noundef %call) #0
60+
call spir_func void @_Z20convert_uint_satfunc(i32 noundef %0) #0
61+
call spir_func void @_Z21convert_float_rtzfunc(float noundef %call) #0
4362
ret void
4463
}
4564

0 commit comments

Comments
 (0)