Skip to content

Commit 6e2f08b

Browse files
authored
[Clang][NFC] Clean up fetching the offloading toolchain (#125095)
Summary: This patch cleans up how we query the offloading toolchain. We create a single that is more similar to the existing `getToolChain` driver function and make all the offloading handlers use it.
1 parent 8d925a1 commit 6e2f08b

File tree

4 files changed

+105
-113
lines changed

4 files changed

+105
-113
lines changed

clang/include/clang/Driver/Driver.h

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -797,22 +797,14 @@ class Driver {
797797
const ToolChain &getToolChain(const llvm::opt::ArgList &Args,
798798
const llvm::Triple &Target) const;
799799

800-
/// @}
801-
802-
/// Retrieves a ToolChain for a particular device \p Target triple
803-
///
804-
/// \param[in] HostTC is the host ToolChain paired with the device
805-
///
806-
/// \param[in] TargetDeviceOffloadKind (e.g. OFK_Cuda/OFK_OpenMP/OFK_SYCL) is
807-
/// an Offloading action that is optionally passed to a ToolChain (used by
808-
/// CUDA, to specify if it's used in conjunction with OpenMP)
800+
/// Retrieves a ToolChain for a particular \p Target triple for offloading.
809801
///
810802
/// Will cache ToolChains for the life of the driver object, and create them
811803
/// on-demand.
812-
const ToolChain &getOffloadingDeviceToolChain(
813-
const llvm::opt::ArgList &Args, const llvm::Triple &Target,
814-
const ToolChain &HostTC,
815-
const Action::OffloadKind &TargetDeviceOffloadKind) const;
804+
const ToolChain &getOffloadToolChain(const llvm::opt::ArgList &Args,
805+
const Action::OffloadKind Kind,
806+
const llvm::Triple &Target,
807+
const llvm::Triple &AuxTarget) const;
816808

817809
/// Get bitmasks for which option flags to include and exclude based on
818810
/// the driver mode.

clang/lib/Driver/Driver.cpp

Lines changed: 98 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -886,42 +886,37 @@ void Driver::CreateOffloadingDeviceToolChains(Compilation &C,
886886
return;
887887
}
888888
if (IsCuda && !UseLLVMOffload) {
889-
const ToolChain *HostTC = C.getSingleOffloadToolChain<Action::OFK_Host>();
890-
const llvm::Triple &HostTriple = HostTC->getTriple();
891-
auto OFK = Action::OFK_Cuda;
892-
auto CudaTriple =
893-
getNVIDIAOffloadTargetTriple(*this, C.getInputArgs(), HostTriple);
889+
auto CudaTriple = getNVIDIAOffloadTargetTriple(
890+
*this, C.getInputArgs(), C.getDefaultToolChain().getTriple());
894891
if (!CudaTriple)
895892
return;
896-
// Use the CUDA and host triples as the key into the ToolChains map,
897-
// because the device toolchain we create depends on both.
898-
auto &CudaTC = ToolChains[CudaTriple->str() + "/" + HostTriple.str()];
899-
if (!CudaTC) {
900-
CudaTC = std::make_unique<toolchains::CudaToolChain>(
901-
*this, *CudaTriple, *HostTC, C.getInputArgs());
902-
903-
// Emit a warning if the detected CUDA version is too new.
904-
CudaInstallationDetector &CudaInstallation =
905-
static_cast<toolchains::CudaToolChain &>(*CudaTC).CudaInstallation;
906-
if (CudaInstallation.isValid())
907-
CudaInstallation.WarnIfUnsupportedVersion();
908-
}
909-
C.addOffloadDeviceToolChain(CudaTC.get(), OFK);
893+
894+
auto &TC =
895+
getOffloadToolChain(C.getInputArgs(), Action::OFK_Cuda, *CudaTriple,
896+
C.getDefaultToolChain().getTriple());
897+
898+
// Emit a warning if the detected CUDA version is too new.
899+
const CudaInstallationDetector &CudaInstallation =
900+
static_cast<const toolchains::CudaToolChain &>(TC).CudaInstallation;
901+
if (CudaInstallation.isValid())
902+
CudaInstallation.WarnIfUnsupportedVersion();
903+
C.addOffloadDeviceToolChain(&TC, Action::OFK_Cuda);
910904
} else if (IsHIP && !UseLLVMOffload) {
911905
if (auto *OMPTargetArg =
912906
C.getInputArgs().getLastArg(options::OPT_fopenmp_targets_EQ)) {
913907
Diag(clang::diag::err_drv_unsupported_opt_for_language_mode)
914908
<< OMPTargetArg->getSpelling() << "HIP";
915909
return;
916910
}
917-
const ToolChain *HostTC = C.getSingleOffloadToolChain<Action::OFK_Host>();
918-
auto OFK = Action::OFK_HIP;
911+
919912
auto HIPTriple = getHIPOffloadTargetTriple(*this, C.getInputArgs());
920913
if (!HIPTriple)
921914
return;
922-
auto *HIPTC = &getOffloadingDeviceToolChain(C.getInputArgs(), *HIPTriple,
923-
*HostTC, OFK);
924-
C.addOffloadDeviceToolChain(HIPTC, OFK);
915+
916+
auto &TC =
917+
getOffloadToolChain(C.getInputArgs(), Action::OFK_HIP, *HIPTriple,
918+
C.getDefaultToolChain().getTriple());
919+
C.addOffloadDeviceToolChain(&TC, Action::OFK_HIP);
925920
}
926921

927922
if (IsCuda || IsHIP)
@@ -1038,40 +1033,17 @@ void Driver::CreateOffloadingDeviceToolChains(Compilation &C,
10381033
FoundNormalizedTriples[NormalizedName] = Val;
10391034

10401035
// If the specified target is invalid, emit a diagnostic.
1041-
if (TT.getArch() == llvm::Triple::UnknownArch)
1036+
if (TT.getArch() == llvm::Triple::UnknownArch) {
10421037
Diag(clang::diag::err_drv_invalid_omp_target) << Val;
1043-
else {
1044-
const ToolChain *TC;
1045-
// Device toolchains have to be selected differently. They pair host
1046-
// and device in their implementation.
1047-
if (TT.isNVPTX() || TT.isAMDGCN() || TT.isSPIRV()) {
1048-
const ToolChain *HostTC =
1049-
C.getSingleOffloadToolChain<Action::OFK_Host>();
1050-
assert(HostTC && "Host toolchain should be always defined.");
1051-
auto &DeviceTC =
1052-
ToolChains[TT.str() + "/" + HostTC->getTriple().normalize()];
1053-
if (!DeviceTC) {
1054-
if (TT.isNVPTX())
1055-
DeviceTC = std::make_unique<toolchains::CudaToolChain>(
1056-
*this, TT, *HostTC, C.getInputArgs());
1057-
else if (TT.isAMDGCN())
1058-
DeviceTC = std::make_unique<toolchains::AMDGPUOpenMPToolChain>(
1059-
*this, TT, *HostTC, C.getInputArgs());
1060-
else if (TT.isSPIRV())
1061-
DeviceTC = std::make_unique<toolchains::SPIRVOpenMPToolChain>(
1062-
*this, TT, *HostTC, C.getInputArgs());
1063-
else
1064-
assert(DeviceTC && "Device toolchain not defined.");
1065-
}
1066-
1067-
TC = DeviceTC.get();
1068-
} else
1069-
TC = &getToolChain(C.getInputArgs(), TT);
1070-
C.addOffloadDeviceToolChain(TC, Action::OFK_OpenMP);
1071-
auto It = DerivedArchs.find(TT.getTriple());
1072-
if (It != DerivedArchs.end())
1073-
KnownArchs[TC] = It->second;
1038+
continue;
10741039
}
1040+
1041+
auto &TC = getOffloadToolChain(C.getInputArgs(), Action::OFK_OpenMP, TT,
1042+
C.getDefaultToolChain().getTriple());
1043+
C.addOffloadDeviceToolChain(&TC, Action::OFK_OpenMP);
1044+
auto It = DerivedArchs.find(TT.getTriple());
1045+
if (It != DerivedArchs.end())
1046+
KnownArchs[&TC] = It->second;
10751047
}
10761048
} else if (C.getInputArgs().hasArg(options::OPT_fopenmp_targets_EQ)) {
10771049
Diag(clang::diag::err_drv_expecting_fopenmp_with_fopenmp_targets);
@@ -1103,9 +1075,9 @@ void Driver::CreateOffloadingDeviceToolChains(Compilation &C,
11031075
// getOffloadingDeviceToolChain, because the device toolchains we're
11041076
// going to create will depend on both.
11051077
const ToolChain *HostTC = C.getSingleOffloadToolChain<Action::OFK_Host>();
1106-
for (const auto &TargetTriple : UniqueSYCLTriplesVec) {
1107-
auto SYCLTC = &getOffloadingDeviceToolChain(
1108-
C.getInputArgs(), TargetTriple, *HostTC, Action::OFK_SYCL);
1078+
for (const auto &TT : UniqueSYCLTriplesVec) {
1079+
auto SYCLTC = &getOffloadToolChain(C.getInputArgs(), Action::OFK_SYCL, TT,
1080+
HostTC->getTriple());
11091081
C.addOffloadDeviceToolChain(SYCLTC, Action::OFK_SYCL);
11101082
}
11111083
}
@@ -6605,6 +6577,73 @@ std::string Driver::GetClPchPath(Compilation &C, StringRef BaseName) const {
66056577
return std::string(Output);
66066578
}
66076579

6580+
const ToolChain &Driver::getOffloadToolChain(
6581+
const llvm::opt::ArgList &Args, const Action::OffloadKind Kind,
6582+
const llvm::Triple &Target, const llvm::Triple &AuxTarget) const {
6583+
std::unique_ptr<ToolChain> &TC =
6584+
ToolChains[Target.str() + "/" + AuxTarget.str()];
6585+
std::unique_ptr<ToolChain> &HostTC = ToolChains[AuxTarget.str()];
6586+
6587+
assert(HostTC && "Host toolchain for offloading doesn't exit?");
6588+
if (!TC) {
6589+
// Detect the toolchain based off of the target operating system.
6590+
switch (Target.getOS()) {
6591+
case llvm::Triple::CUDA:
6592+
TC = std::make_unique<toolchains::CudaToolChain>(*this, Target, *HostTC,
6593+
Args);
6594+
break;
6595+
case llvm::Triple::AMDHSA:
6596+
if (Kind == Action::OFK_HIP)
6597+
TC = std::make_unique<toolchains::HIPAMDToolChain>(*this, Target,
6598+
*HostTC, Args);
6599+
else if (Kind == Action::OFK_OpenMP)
6600+
TC = std::make_unique<toolchains::AMDGPUOpenMPToolChain>(*this, Target,
6601+
*HostTC, Args);
6602+
break;
6603+
default:
6604+
break;
6605+
}
6606+
}
6607+
if (!TC) {
6608+
// Detect the toolchain based off of the target architecture if that failed.
6609+
switch (Target.getArch()) {
6610+
case llvm::Triple::spir:
6611+
case llvm::Triple::spir64:
6612+
case llvm::Triple::spirv:
6613+
case llvm::Triple::spirv32:
6614+
case llvm::Triple::spirv64:
6615+
switch (Kind) {
6616+
case Action::OFK_SYCL:
6617+
TC = std::make_unique<toolchains::SYCLToolChain>(*this, Target, *HostTC,
6618+
Args);
6619+
break;
6620+
case Action::OFK_HIP:
6621+
TC = std::make_unique<toolchains::HIPSPVToolChain>(*this, Target,
6622+
*HostTC, Args);
6623+
break;
6624+
case Action::OFK_OpenMP:
6625+
TC = std::make_unique<toolchains::SPIRVOpenMPToolChain>(*this, Target,
6626+
*HostTC, Args);
6627+
break;
6628+
case Action::OFK_Cuda:
6629+
TC = std::make_unique<toolchains::CudaToolChain>(*this, Target, *HostTC,
6630+
Args);
6631+
break;
6632+
default:
6633+
break;
6634+
}
6635+
break;
6636+
default:
6637+
break;
6638+
}
6639+
}
6640+
6641+
// If all else fails, just look up the normal toolchain for the target.
6642+
if (!TC)
6643+
return getToolChain(Args, Target);
6644+
return *TC;
6645+
}
6646+
66086647
const ToolChain &Driver::getToolChain(const ArgList &Args,
66096648
const llvm::Triple &Target) const {
66106649

@@ -6798,45 +6837,6 @@ const ToolChain &Driver::getToolChain(const ArgList &Args,
67986837
return *TC;
67996838
}
68006839

6801-
const ToolChain &Driver::getOffloadingDeviceToolChain(
6802-
const ArgList &Args, const llvm::Triple &Target, const ToolChain &HostTC,
6803-
const Action::OffloadKind &TargetDeviceOffloadKind) const {
6804-
// Use device / host triples as the key into the ToolChains map because the
6805-
// device ToolChain we create depends on both.
6806-
auto &TC = ToolChains[Target.str() + "/" + HostTC.getTriple().str()];
6807-
if (!TC) {
6808-
// Categorized by offload kind > arch rather than OS > arch like
6809-
// the normal getToolChain call, as it seems a reasonable way to categorize
6810-
// things.
6811-
switch (TargetDeviceOffloadKind) {
6812-
case Action::OFK_HIP: {
6813-
if (((Target.getArch() == llvm::Triple::amdgcn ||
6814-
Target.getArch() == llvm::Triple::spirv64) &&
6815-
Target.getVendor() == llvm::Triple::AMD &&
6816-
Target.getOS() == llvm::Triple::AMDHSA) ||
6817-
!Args.hasArgNoClaim(options::OPT_offload_EQ))
6818-
TC = std::make_unique<toolchains::HIPAMDToolChain>(*this, Target,
6819-
HostTC, Args);
6820-
else if (Target.getArch() == llvm::Triple::spirv64 &&
6821-
Target.getVendor() == llvm::Triple::UnknownVendor &&
6822-
Target.getOS() == llvm::Triple::UnknownOS)
6823-
TC = std::make_unique<toolchains::HIPSPVToolChain>(*this, Target,
6824-
HostTC, Args);
6825-
break;
6826-
}
6827-
case Action::OFK_SYCL:
6828-
if (Target.isSPIROrSPIRV())
6829-
TC = std::make_unique<toolchains::SYCLToolChain>(*this, Target, HostTC,
6830-
Args);
6831-
break;
6832-
default:
6833-
break;
6834-
}
6835-
}
6836-
assert(TC && "Could not create offloading device tool chain.");
6837-
return *TC;
6838-
}
6839-
68406840
bool Driver::ShouldUseClangCompiler(const JobAction &JA) const {
68416841
// Say "no" if there is not exactly one input of a type clang understands.
68426842
if (JA.size() != 1 ||

clang/lib/Driver/ToolChains/Cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ CudaVersion parseCudaHFile(llvm::StringRef Input) {
123123
}
124124
} // namespace
125125

126-
void CudaInstallationDetector::WarnIfUnsupportedVersion() {
126+
void CudaInstallationDetector::WarnIfUnsupportedVersion() const {
127127
if (Version > CudaVersion::PARTIALLY_SUPPORTED) {
128128
std::string VersionString = CudaVersionToString(Version);
129129
if (!VersionString.empty())

clang/lib/Driver/ToolChains/Cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class CudaInstallationDetector {
7474
std::string getLibDeviceFile(StringRef Gpu) const {
7575
return LibDeviceMap.lookup(Gpu);
7676
}
77-
void WarnIfUnsupportedVersion();
77+
void WarnIfUnsupportedVersion() const;
7878
};
7979

8080
namespace tools {

0 commit comments

Comments
 (0)