@@ -744,56 +744,26 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
744
744
if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
745
745
SrcTy = VecTy->getElementType ();
746
746
auto IsTargetInt = isa<IntegerType>(TargetTy);
747
- auto TargetSigned = DemangledName[8 ] != ' u' ;
748
747
749
- std::string TargetTyName (
750
- DemangledName.substr (strlen (kOCLBuiltinName ::ConvertPrefix)));
751
- auto FirstUnderscoreLoc = TargetTyName.find (' _' );
752
- if (FirstUnderscoreLoc != std::string::npos)
753
- TargetTyName = TargetTyName.substr (0 , FirstUnderscoreLoc);
754
-
755
- // Validate target type name
756
- std::regex Expr (" ([a-z]+)([0-9]*)$" );
748
+ // Validate conversion function name and vector size if present
749
+ std::regex Expr (
750
+ " convert_(float|double|half|u?char|u?short|u?int|u?long)(2|3|4|8|16)*"
751
+ " (_sat)*(_rt[ezpn])*$" );
757
752
std::smatch DestTyMatch;
758
- if (!std::regex_match (TargetTyName, DestTyMatch, Expr))
753
+ std::string ConversionFunc (DemangledName.str ());
754
+ if (!std::regex_match (ConversionFunc, DestTyMatch, Expr))
759
755
return ;
760
756
761
757
// The first sub_match is the whole string; the next
762
- // sub_match is the first parenthesized expression.
763
- std::string DestTy = DestTyMatch[1 ].str ();
764
-
765
- // check it's valid type name
766
- static std::unordered_set<std::string> ValidTypes = {
767
- " float" , " double" , " half" , " char" , " uchar" , " short" ,
768
- " ushort" , " int" , " uint" , " long" , " ulong" };
769
-
770
- if (ValidTypes.find (DestTy) == ValidTypes.end ())
771
- return ;
772
-
773
- // check that it's allowed vector size
774
- std::string VecSize = DestTyMatch[2 ].str ();
775
- if (!VecSize.empty ()) {
776
- int Size = stoi (VecSize);
777
- switch (Size ) {
778
- case 2 :
779
- case 3 :
780
- case 4 :
781
- case 8 :
782
- case 16 :
783
- break ;
784
- default :
785
- return ;
786
- }
787
- }
788
- DemangledName = DemangledName.drop_front (
789
- strlen (kOCLBuiltinName ::ConvertPrefix) + TargetTyName.size ());
790
- TargetTyName = std::string (" _R" ) + TargetTyName;
758
+ // sub_matches are the parenthesized expressions.
759
+ enum { TypeIdx = 1 , VecSizeIdx = 2 , SatIdx = 3 , RoundingIdx = 4 };
760
+ std::string DestTy = DestTyMatch[TypeIdx].str ();
761
+ std::string VecSize = DestTyMatch[VecSizeIdx].str ();
762
+ std::string Sat = DestTyMatch[SatIdx].str ();
763
+ std::string Rounding = DestTyMatch[RoundingIdx].str ();
791
764
792
- if (!DemangledName.empty () && !DemangledName.starts_with (" _sat" ) &&
793
- !DemangledName.starts_with (" _rt" ))
794
- return ;
765
+ bool TargetSigned = DestTy[0 ] != ' u' ;
795
766
796
- std::string Sat = DemangledName.find (" _sat" ) != StringRef::npos ? " _sat" : " " ;
797
767
if (isa<IntegerType>(SrcTy)) {
798
768
bool Signed = isLastFuncParamSigned (MangledName);
799
769
if (IsTargetInt) {
@@ -810,13 +780,13 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
810
780
} else
811
781
OC = OpFConvert;
812
782
}
813
- auto Loc = DemangledName.find (" _rt" );
814
- std::string Rounding;
815
- if (Loc != StringRef::npos && !(isa<IntegerType>(SrcTy) && IsTargetInt)) {
816
- Rounding = DemangledName.substr (Loc, 4 ).str ();
817
- }
783
+
784
+ if (!Rounding.empty () && (isa<IntegerType>(SrcTy) && IsTargetInt))
785
+ return ;
786
+
818
787
assert (CI->getCalledFunction () && " Unexpected indirect call" );
819
- mutateCallInst (CI, getSPIRVFuncName (OC, TargetTyName + Sat + Rounding));
788
+ mutateCallInst (
789
+ CI, getSPIRVFuncName (OC, " _R" + DestTy + VecSize + Sat + Rounding));
820
790
}
821
791
822
792
void OCLToSPIRVBase::visitCallGroupBuiltin (CallInst *CI,
0 commit comments