Skip to content

Commit 6dc419f

Browse files
joppermsommerlukas
andauthored
[SYCL][RTC] Cache frontend invocation (#16823)
Adds `sycl_jit`-RTC-specific persistent caching to the runtime. The basic idea is to cache only the LLVM module resulting from the device compilation in bitcode format. Device linking and post-link would be run always (even for a cache hit), as invoking the frontend is the most expensive step in the pipeline right now. The cache key is the concatenation of: - the Base64*-encoding of a BLAKE3 hash of the preprocessed source string (i.e. containing all headers included as virtual files per the `kernel_compiler` extension as well as from the local file system), and - the Base64-encoding of a BLAKE3 hash of the user-supplied build options. *) Replacing `/` by `-` to make the string filesystem-friendly. --------- Signed-off-by: Julian Oppermann <[email protected]> Co-authored-by: Lukas Sommer <[email protected]>
1 parent 47630fe commit 6dc419f

16 files changed

+625
-70
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,9 @@ struct RTCDevImgInfo {
411411

412412
using RTCBundleInfo = DynArray<RTCDevImgInfo>;
413413

414+
// LLVM's APIs prefer `char *` for byte buffers.
415+
using RTCDeviceCodeIR = DynArray<char>;
416+
414417
} // namespace jit_compiler
415418

416419
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_llvm_library(sycl-jit
1818

1919
LINK_COMPONENTS
2020
BitReader
21+
BitWriter
2122
Core
2223
Support
2324
Option

sycl-jit/jit-compiler/include/KernelFusion.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,45 @@ class JITResult {
5656
sycl::detail::string ErrorMessage;
5757
};
5858

59+
class RTCHashResult {
60+
public:
61+
static RTCHashResult success(const char *Hash) {
62+
return RTCHashResult{/*Failed=*/false, Hash};
63+
}
64+
65+
static RTCHashResult failure(const char *PreprocLog) {
66+
return RTCHashResult{/*Failed=*/true, PreprocLog};
67+
}
68+
69+
bool failed() { return Failed; }
70+
71+
const char *getPreprocLog() {
72+
assert(failed() && "No preprocessor log");
73+
return HashOrLog.c_str();
74+
}
75+
76+
const char *getHash() {
77+
assert(!failed() && "No hash");
78+
return HashOrLog.c_str();
79+
}
80+
81+
private:
82+
RTCHashResult(bool Failed, const char *HashOrLog)
83+
: Failed(Failed), HashOrLog(HashOrLog) {}
84+
85+
bool Failed;
86+
sycl::detail::string HashOrLog;
87+
};
88+
5989
class RTCResult {
6090
public:
6191
explicit RTCResult(const char *BuildLog)
6292
: Failed{true}, BundleInfo{}, BuildLog{BuildLog} {}
6393

64-
RTCResult(RTCBundleInfo &&BundleInfo, const char *BuildLog)
65-
: Failed{false}, BundleInfo{std::move(BundleInfo)}, BuildLog{BuildLog} {}
94+
RTCResult(RTCBundleInfo &&BundleInfo, RTCDeviceCodeIR &&DeviceCodeIR,
95+
const char *BuildLog)
96+
: Failed{false}, BundleInfo{std::move(BundleInfo)},
97+
DeviceCodeIR(std::move(DeviceCodeIR)), BuildLog{BuildLog} {}
6698

6799
bool failed() const { return Failed; }
68100

@@ -73,9 +105,15 @@ class RTCResult {
73105
return BundleInfo;
74106
}
75107

108+
const RTCDeviceCodeIR &getDeviceCodeIR() const {
109+
assert(!failed() && "No device code IR");
110+
return DeviceCodeIR;
111+
}
112+
76113
private:
77114
bool Failed;
78115
RTCBundleInfo BundleInfo;
116+
RTCDeviceCodeIR DeviceCodeIR;
79117
sycl::detail::string BuildLog;
80118
};
81119

@@ -100,9 +138,14 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
100138
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
101139
View<unsigned char> SpecConstBlob);
102140

141+
KF_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
142+
View<InMemoryFile> IncludeFiles,
143+
View<const char *> UserArgs);
144+
103145
KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
104146
View<InMemoryFile> IncludeFiles,
105-
View<const char *> UserArgs);
147+
View<const char *> UserArgs,
148+
View<char> CachedIR, bool SaveIR);
106149

107150
KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address);
108151

sycl-jit/jit-compiler/ld-version-script.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/* Export the library entry points */
44
fuseKernels;
55
materializeSpecConstants;
6+
calculateHash;
67
compileSYCL;
78
destroyBinary;
89
resetJITConfiguration;

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
#include "translation/SPIRVLLVMTranslation.h"
2020

2121
#include <llvm/ADT/StringExtras.h>
22+
#include <llvm/Bitcode/BitcodeReader.h>
23+
#include <llvm/Bitcode/BitcodeWriter.h>
2224
#include <llvm/Support/Error.h>
25+
#include <llvm/Support/MemoryBuffer.h>
2326
#include <llvm/Support/TimeProfiler.h>
2427

2528
#include <clang/Driver/Options.h>
@@ -31,17 +34,21 @@ using namespace jit_compiler;
3134
using FusedFunction = helper::FusionHelper::FusedFunction;
3235
using FusedFunctionList = std::vector<FusedFunction>;
3336

34-
template <typename ResultType>
35-
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
37+
static std::string formatError(llvm::Error &&Err, const std::string &Msg) {
3638
std::stringstream ErrMsg;
3739
ErrMsg << Msg << "\nDetailed information:\n";
3840
llvm::handleAllErrors(std::move(Err),
3941
[&ErrMsg](const llvm::StringError &StrErr) {
40-
// Cannot throw an exception here if LLVM itself is
41-
// compiled without exception support.
4242
ErrMsg << "\t" << StrErr.getMessage() << "\n";
4343
});
44-
return ResultType{ErrMsg.str().c_str()};
44+
return ErrMsg.str();
45+
}
46+
47+
template <typename ResultType>
48+
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
49+
// Cannot throw an exception here if LLVM itself is compiled without exception
50+
// support.
51+
return ResultType{formatError(std::move(Err), Msg).c_str()};
4552
}
4653

4754
static std::vector<jit_compiler::NDRange>
@@ -240,10 +247,42 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
240247
return JITResult{FusedKernelInfo};
241248
}
242249

250+
extern "C" KF_EXPORT_SYMBOL RTCHashResult
251+
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
252+
View<const char *> UserArgs) {
253+
auto UserArgListOrErr = parseUserArgs(UserArgs);
254+
if (!UserArgListOrErr) {
255+
return RTCHashResult::failure(
256+
formatError(UserArgListOrErr.takeError(),
257+
"Parsing of user arguments failed")
258+
.c_str());
259+
}
260+
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);
261+
262+
auto Start = std::chrono::high_resolution_clock::now();
263+
auto HashOrError = calculateHash(SourceFile, IncludeFiles, UserArgList);
264+
if (!HashOrError) {
265+
return RTCHashResult::failure(
266+
formatError(HashOrError.takeError(), "Hashing failed").c_str());
267+
}
268+
auto Hash = *HashOrError;
269+
auto Stop = std::chrono::high_resolution_clock::now();
270+
271+
if (UserArgList.hasArg(clang::driver::options::OPT_ftime_trace_EQ)) {
272+
std::chrono::duration<double, std::milli> HashTime = Stop - Start;
273+
llvm::dbgs() << "Hashing of " << SourceFile.Path << " took "
274+
<< int(HashTime.count()) << " ms\n";
275+
}
276+
277+
return RTCHashResult::success(Hash.c_str());
278+
}
279+
243280
extern "C" KF_EXPORT_SYMBOL RTCResult
244281
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
245-
View<const char *> UserArgs) {
282+
View<const char *> UserArgs, View<char> CachedIR, bool SaveIR) {
283+
llvm::LLVMContext Context;
246284
std::string BuildLog;
285+
configureDiagnostics(Context, BuildLog);
247286

248287
auto UserArgListOrErr = parseUserArgs(UserArgs);
249288
if (!UserArgListOrErr) {
@@ -272,16 +311,43 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
272311
Verbose);
273312
}
274313

275-
auto ModuleOrErr =
276-
compileDeviceCode(SourceFile, IncludeFiles, UserArgList, BuildLog);
277-
if (!ModuleOrErr) {
278-
return errorTo<RTCResult>(ModuleOrErr.takeError(),
279-
"Device compilation failed");
314+
std::unique_ptr<llvm::Module> Module;
315+
316+
if (CachedIR.size() > 0) {
317+
llvm::StringRef IRStr{CachedIR.begin(), CachedIR.size()};
318+
std::unique_ptr<llvm::MemoryBuffer> IRBuf =
319+
llvm::MemoryBuffer::getMemBuffer(IRStr, /*BufferName=*/"",
320+
/*RequiresNullTerminator=*/false);
321+
auto ModuleOrError = llvm::parseBitcodeFile(*IRBuf, Context);
322+
if (!ModuleOrError) {
323+
// Not a fatal error, we'll just compile the source string normally.
324+
BuildLog.append(formatError(ModuleOrError.takeError(),
325+
"Loading of cached device code failed"));
326+
} else {
327+
Module = std::move(*ModuleOrError);
328+
}
280329
}
281330

282-
std::unique_ptr<llvm::LLVMContext> Context;
283-
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
284-
Context.reset(&Module->getContext());
331+
bool FromSource = false;
332+
if (!Module) {
333+
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
334+
BuildLog, Context);
335+
if (!ModuleOrErr) {
336+
return errorTo<RTCResult>(ModuleOrErr.takeError(),
337+
"Device compilation failed");
338+
}
339+
340+
Module = std::move(*ModuleOrErr);
341+
FromSource = true;
342+
}
343+
344+
RTCDeviceCodeIR IR;
345+
if (SaveIR && FromSource) {
346+
std::string BCString;
347+
llvm::raw_string_ostream BCStream{BCString};
348+
llvm::WriteBitcodeToFile(*Module, BCStream);
349+
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
350+
}
285351

286352
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
287353
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
@@ -314,7 +380,7 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
314380
}
315381
}
316382

317-
return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
383+
return RTCResult{std::move(BundleInfo), std::move(IR), BuildLog.c_str()};
318384
}
319385

320386
extern "C" KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address) {

0 commit comments

Comments
 (0)