Skip to content

Commit 63c61d8

Browse files
authored
[SYCL] Add support for JIT-ing in AMD and NVIDIA backends (#14280)
This patch provides the following:
* Support for JIT compilation of NVIDIA and AMD kernels. This is guarded by the `SYCL_JIT_KERNELS` environment variable. The target CPU and features can be provided through environment variables (`SYCL_JIT_TARGET_CPU` and `SYCL_JIT_TARGET_FEATURES`, respectively); otherwise default, per-backend values will be chosen.
* Caching of JIT-compiled kernels. The runtime maintains a map of available JIT-ed kernels, accessible through a key that is constructed from the kernel's name and the values of its specialization constants (if provided).
* Materialization of specialization constants. Materialization is done through a `SYCLSpecConstMaterializer` pass that receives the values of all specialization constants used by a kernel (from `SYCLSpecConstDataInserter`) and then walks all uses of the implicit kernel argument (`_arg__specialization_constants_buffer`), which represents the emulated specialization constants, replacing them with concrete values and thereby turning them into de facto compile-time constants.
This PR extends the work done for kernel fusion and, in a similar fashion, requires embedding of IR (`-fsycl-embed-ir`) during the initial compilation.
1 parent d6780ae commit 63c61d8

32 files changed

+1361
-160
lines changed

sycl-fusion/jit-compiler/include/KernelFusion.h

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,18 @@
1919

2020
namespace jit_compiler {
2121

22-
class FusionResult {
22+
class JITResult {
2323
public:
24-
explicit FusionResult(const char *ErrorMessage)
25-
: Type{FusionResultType::FAILED}, KernelInfo{},
26-
ErrorMessage{ErrorMessage} {}
24+
explicit JITResult(const char *ErrorMessage)
25+
: Type{JITResultType::FAILED}, KernelInfo{}, ErrorMessage{ErrorMessage} {}
2726

28-
explicit FusionResult(const SYCLKernelInfo &KernelInfo, bool Cached = false)
29-
: Type{(Cached) ? FusionResultType::CACHED : FusionResultType::NEW},
27+
explicit JITResult(const SYCLKernelInfo &KernelInfo, bool Cached = false)
28+
: Type{(Cached) ? JITResultType::CACHED : JITResultType::NEW},
3029
KernelInfo(KernelInfo), ErrorMessage{} {}
3130

32-
bool failed() const { return Type == FusionResultType::FAILED; }
31+
bool failed() const { return Type == JITResultType::FAILED; }
3332

34-
bool cached() const { return Type == FusionResultType::CACHED; }
33+
bool cached() const { return Type == JITResultType::CACHED; }
3534

3635
const char *getErrorMessage() const {
3736
assert(failed() && "No error message present");
@@ -44,9 +43,9 @@ class FusionResult {
4443
}
4544

4645
private:
47-
enum class FusionResultType { FAILED, CACHED, NEW };
46+
enum class JITResultType { FAILED, CACHED, NEW };
4847

49-
FusionResultType Type;
48+
JITResultType Type;
5049
SYCLKernelInfo KernelInfo;
5150
sycl::detail::string ErrorMessage;
5251
};
@@ -56,12 +55,18 @@ extern "C" {
5655
#ifdef __clang__
5756
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
5857
#endif // __clang__
59-
FusionResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
60-
const char *FusedKernelName,
61-
View<ParameterIdentity> Identities,
62-
BarrierFlags BarriersFlags,
63-
View<ParameterInternalization> Internalization,
64-
View<jit_compiler::JITConstant> JITConstants);
58+
JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
59+
const char *FusedKernelName,
60+
View<ParameterIdentity> Identities,
61+
BarrierFlags BarriersFlags,
62+
View<ParameterInternalization> Internalization,
63+
View<jit_compiler::JITConstant> JITConstants);
64+
65+
/// Materialize the specialization constants of a single kernel and translate
/// the result to the configured JIT target format.
///
/// \param KernelName name of the kernel to process.
/// \param BinInfo binary information of the kernel; it is stored in the
///        SYCLKernelInfo the implementation builds for \p KernelName.
/// \param SpecConstBlob raw bytes handed to the materializer pass pipeline.
/// \param TargetCPU forwarded unchanged to the kernel translation step.
/// \param TargetFeatures forwarded unchanged to the kernel translation step.
/// \return a failed JITResult (see JITResult::failed()) carrying an error
///         message if the configured target format is neither PTX nor AMDGCN,
///         or if loading, materialization or translation fails; otherwise a
///         JITResult holding the translated kernel info.
JITResult materializeSpecConstants(const char *KernelName,
66+
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
67+
View<unsigned char> SpecConstBlob,
68+
const char *TargetCPU,
69+
const char *TargetFeatures);
6570

6671
/// Clear all previously set options.
6772
void resetJITConfiguration();

sycl-fusion/jit-compiler/ld-version-script.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
global:
33
/* Export the library entry points */
44
fuseKernels;
5+
materializeSpecConstants;
56
resetJITConfiguration;
67
addToJITConfiguration;
78

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ using namespace jit_compiler;
2424
using FusedFunction = helper::FusionHelper::FusedFunction;
2525
using FusedFunctionList = std::vector<FusedFunction>;
2626

27-
static FusionResult errorToFusionResult(llvm::Error &&Err,
28-
const std::string &Msg) {
27+
static JITResult errorToFusionResult(llvm::Error &&Err,
28+
const std::string &Msg) {
2929
std::stringstream ErrMsg;
3030
ErrMsg << Msg << "\nDetailed information:\n";
3131
llvm::handleAllErrors(std::move(Err),
@@ -34,7 +34,7 @@ static FusionResult errorToFusionResult(llvm::Error &&Err,
3434
// compiled without exception support.
3535
ErrMsg << "\t" << StrErr.getMessage() << "\n";
3636
});
37-
return FusionResult{ErrMsg.str().c_str()};
37+
return JITResult{ErrMsg.str().c_str()};
3838
}
3939

4040
static std::vector<jit_compiler::NDRange>
@@ -70,11 +70,58 @@ static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
7070
}
7171
}
7272

73-
extern "C" FusionResult
74-
fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
75-
View<ParameterIdentity> Identities, BarrierFlags BarriersFlags,
76-
View<ParameterInternalization> Internalization,
77-
View<jit_compiler::JITConstant> Constants) {
73+
/// Library entry point: materialize the specialization constants of one kernel.
///
/// Wraps \p KernelName and \p BinInfo in a temporary SYCLKernelInfo, loads it
/// into an LLVM module, runs the materializer pass pipeline with
/// \p SpecConstBlob and translates the module to the configured target format.
///
/// \param KernelName name of the kernel; a function with this name must still
///        be present in the module after the passes have run.
/// \param BinInfo binary information stored in the temporary SYCLKernelInfo.
/// \param SpecConstBlob bytes forwarded to
///        FusionPipeline::runMaterializerPasses.
/// \param TargetCPU forwarded unchanged to KernelTranslator::translateKernel.
/// \param TargetFeatures forwarded unchanged to
///        KernelTranslator::translateKernel.
/// \return a failed JITResult with an error message on any failure below,
///         otherwise a JITResult (constructed with the default Cached = false)
///         holding the translated kernel info.
///
/// Unlike fuseKernels, this function neither consults nor updates the cache
/// held by JITContext.
extern "C" JITResult
74+
materializeSpecConstants(const char *KernelName,
75+
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
76+
View<unsigned char> SpecConstBlob,
77+
const char *TargetCPU, const char *TargetFeatures) {
78+
auto &JITCtx = JITContext::getInstance();
79+
80+
// The output format comes from the global JIT configuration, not from a
// parameter of this function.
TargetInfo TargetInfo = ConfigHelper::get<option::JITTargetInfo>();
81+
BinaryFormat TargetFormat = TargetInfo.getFormat();
82+
// Only PTX and AMDGCN are accepted here; every other format is rejected.
if (TargetFormat != BinaryFormat::PTX &&
83+
TargetFormat != BinaryFormat::AMDGCN) {
84+
return JITResult("Output target format not supported by this build. "
85+
"Available targets are: PTX or AMDGCN.");
86+
}
87+
88+
// Build a kernel description with default-constructed argument descriptor
// and ND-range; only the name and the binary info are supplied by the caller.
::jit_compiler::SYCLKernelInfo KernelInfo{
89+
KernelName, ::jit_compiler::SYCLArgumentDescriptor{},
90+
::jit_compiler::NDRange{}, BinInfo};
91+
SYCLModuleInfo ModuleInfo;
92+
ModuleInfo.kernels().insert(ModuleInfo.kernels().end(), KernelInfo);
93+
// Load the kernels listed in ModuleInfo (here: only the one inserted above)
94+
// into an LLVM IR module.
95+
llvm::Expected<std::unique_ptr<llvm::Module>> ModOrError =
96+
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
97+
ModuleInfo.kernels());
98+
if (auto Error = ModOrError.takeError()) {
99+
return errorToFusionResult(std::move(Error), "Failed to load kernels");
100+
}
101+
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
102+
// Fail if the pipeline reports failure or if the kernel function is no
// longer present in the module afterwards.
if (!fusion::FusionPipeline::runMaterializerPasses(
103+
*NewMod, SpecConstBlob.to<llvm::ArrayRef>()) ||
104+
!NewMod->getFunction(KernelName)) {
105+
return JITResult{"Materializer passes should not fail"};
106+
}
107+
108+
// NOTE(review): the result of getKernelFor is dereferenced without a
// hasKernelFor check (fuseKernels performs one); presumably the lookup
// always finds the entry inserted into ModuleInfo above -- confirm.
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
109+
if (auto Error = translation::KernelTranslator::translateKernel(
110+
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat, TargetCPU,
111+
TargetFeatures)) {
112+
return errorToFusionResult(std::move(Error),
113+
"Translation to output format failed");
114+
}
115+
116+
return JITResult{MaterializerKernelInfo};
117+
}
118+
119+
extern "C" JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
120+
const char *FusedKernelName,
121+
View<ParameterIdentity> Identities,
122+
BarrierFlags BarriersFlags,
123+
View<ParameterInternalization> Internalization,
124+
View<jit_compiler::JITConstant> Constants) {
78125

79126
std::vector<std::string> KernelsToFuse;
80127
llvm::transform(KernelInformation, std::back_inserter(KernelsToFuse),
@@ -93,8 +140,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
93140
}
94141

95142
if (!isTargetFormatSupported(TargetFormat)) {
96-
return FusionResult(
97-
"Fusion output target format not supported by this build");
143+
return JITResult("Fusion output target format not supported by this build");
98144
}
99145

100146
auto &JITCtx = JITContext::getInstance();
@@ -117,7 +163,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
117163
// before returning the kernel info to the runtime.
118164
CachedKernel->NDR = FusedNDR->getNDR();
119165
}
120-
return FusionResult{*CachedKernel, /*Cached*/ true};
166+
return JITResult{*CachedKernel, /*Cached*/ true};
121167
}
122168
helper::printDebugMessage(
123169
"Compiling new kernel, no suitable cached kernel found");
@@ -165,13 +211,13 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
165211
BarriersFlags);
166212

167213
if (!NewMod->getFunction(FusedKernelName)) {
168-
return FusionResult{"Kernel fusion failed"};
214+
return JITResult{"Kernel fusion failed"};
169215
}
170216

171217
// Get the updated kernel info for the fused kernel and add the information to
172218
// the existing KernelInfo.
173219
if (!NewModInfo->hasKernelFor(FusedKernelName)) {
174-
return FusionResult{"No KernelInfo for fused kernel"};
220+
return JITResult{"No KernelInfo for fused kernel"};
175221
}
176222

177223
SYCLKernelInfo &FusedKernelInfo = *NewModInfo->getKernelFor(FusedKernelName);
@@ -188,7 +234,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
188234
JITCtx.addCacheEntry(CacheKey, FusedKernelInfo);
189235
}
190236

191-
return FusionResult{FusedKernelInfo};
237+
return JITResult{FusedKernelInfo};
192238
}
193239

194240
extern "C" void resetJITConfiguration() { ConfigHelper::reset(); }

sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "helper/ConfigHelper.h"
1313
#include "internalization/Internalization.h"
1414
#include "kernel-fusion/SYCLKernelFusion.h"
15+
#include "kernel-fusion/SYCLSpecConstMaterializer.h"
1516
#include "kernel-info/SYCLKernelInfo.h"
1617
#include "syclcp/SYCLCP.h"
1718

@@ -141,3 +142,48 @@ FusionPipeline::runFusionPasses(Module &Mod, SYCLModuleInfo &InputInfo,
141142

142143
return std::make_unique<SYCLModuleInfo>(std::move(*NewModInfo.ModuleInfo));
143144
}
145+
146+
/// Run the specialization-constant materialization pipeline on \p Mod.
///
/// The pipeline consists of three stages, added to one ModulePassManager in
/// this order: (1) SYCLSpecConstDataInserter followed by
/// SYCLSpecConstMaterializer, (2) inlining plus scalar clean-up passes, and
/// (3) induction-variable simplification plus loop unrolling.
///
/// \param Mod module that is transformed in place.
/// \param SpecConstData bytes handed to SYCLSpecConstDataInserter.
/// \return always true; no failure path exists in this function.
bool FusionPipeline::runMaterializerPasses(
147+
llvm::Module &Mod, llvm::ArrayRef<unsigned char> SpecConstData) {
148+
// Set up the analysis managers and register the analyses with them before
// any pass is run.
PassBuilder PB;
149+
LoopAnalysisManager LAM;
150+
FunctionAnalysisManager FAM;
151+
CGSCCAnalysisManager CGAM;
152+
ModuleAnalysisManager MAM;
153+
PB.registerModuleAnalyses(MAM);
154+
PB.registerCGSCCAnalyses(CGAM);
155+
PB.registerFunctionAnalyses(FAM);
156+
PB.registerLoopAnalyses(LAM);
157+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
158+
159+
ModulePassManager MPM;
160+
// Register inserter and materializer passes.
161+
{
162+
FunctionPassManager FPM;
163+
// The inserter is a module pass added to MPM ahead of the function-pass
// adaptor, so it is queued before the per-function materializer.
MPM.addPass(SYCLSpecConstDataInserter{SpecConstData});
164+
FPM.addPass(SYCLSpecConstMaterializer{});
165+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
166+
}
167+
// Add generic optimizations,
168+
{
169+
FunctionPassManager FPM;
170+
// AlwaysInlinerPass is added to MPM before the adaptor below, so it is
// queued ahead of the function passes collected in FPM.
MPM.addPass(AlwaysInlinerPass{});
171+
FPM.addPass(SROAPass{SROAOptions::ModifyCFG});
172+
FPM.addPass(SCCPPass{});
173+
FPM.addPass(ADCEPass{});
174+
FPM.addPass(EarlyCSEPass{/*UseMemorySSA*/ true});
175+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
176+
}
177+
// followed by unrolling.
178+
{
179+
FunctionPassManager FPM;
180+
FPM.addPass(createFunctionToLoopPassAdaptor(IndVarSimplifyPass{}));
181+
// Unrolling uses default-constructed LoopUnrollOptions.
LoopUnrollOptions UnrollOptions;
182+
FPM.addPass(LoopUnrollPass{UnrollOptions});
183+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
184+
}
185+
186+
MPM.run(Mod, MAM);
187+
188+
// The result of MPM.run is not inspected; success is reported
// unconditionally.
return true;
189+
}

sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ class FusionPipeline {
2727
static std::unique_ptr<SYCLModuleInfo>
2828
runFusionPasses(llvm::Module &Mod, SYCLModuleInfo &InputInfo,
2929
BarrierFlags BarriersFlags);
30+
31+
///
32+
/// Run the necessary passes in a custom pass pipeline to perform
33+
/// materialization of kernel specialization constants.
///
/// \param Mod module the passes are run on; it is modified in place.
/// \param SpecConstData bytes holding the specialization constant values that
///        are handed to the pipeline.
/// \return true on success. The definition in FusionPipeline.cpp currently
///         returns true unconditionally.
34+
static bool
35+
runMaterializerPasses(llvm::Module &Mod,
36+
llvm::ArrayRef<unsigned char> SpecConstData);
3037
};
3138
} // namespace fusion
3239
} // namespace jit_compiler

0 commit comments

Comments
 (0)