Skip to content

Commit 7b47ebb

Browse files
authored
[SYCL] Move bfloat support from experimental to supported. (#6524)
This change makes bfloat16 a supported feature. Signed-off-by: Rajiv Deodhar <[email protected]> Signed-off-by: JackAKirk <[email protected]>
1 parent 67f6bba commit 7b47ebb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1554
-815
lines changed

clang/lib/Basic/Targets/NVPTX.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ void NVPTXTargetInfo::getTargetDefines(const LangOptions &Opts,
168168
MacroBuilder &Builder) const {
169169
Builder.defineMacro("__PTX__");
170170
Builder.defineMacro("__NVPTX__");
171-
if (Opts.CUDAIsDevice || Opts.OpenMPIsDevice) {
171+
if (Opts.CUDAIsDevice || Opts.OpenMPIsDevice || Opts.SYCLIsDevice) {
172172
// Set __CUDA_ARCH__ for the GPU specified.
173173
std::string CUDAArchCode = [this] {
174174
switch (GPU) {

clang/lib/Driver/Driver.cpp

+87
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
#include <cstdlib> // ::getenv
102102
#include <map>
103103
#include <memory>
104+
#include <regex>
105+
#include <sstream>
104106
#include <utility>
105107
#if LLVM_ON_UNIX
106108
#include <unistd.h> // getpid
@@ -5064,6 +5066,76 @@ class OffloadingActionBuilder final {
50645066
}
50655067
}
50665068

5069+
// Return whether to use native bfloat16 library.
5070+
bool selectBfloatLibs(const ToolChain *TC, bool &useNative) {
5071+
const OptTable &Opts = C.getDriver().getOpts();
5072+
const char *TargetOpt = nullptr;
5073+
const char *DeviceOpt = nullptr;
5074+
bool needLibs = false;
5075+
for (auto *A : Args) {
5076+
llvm::Triple *TargetBE = nullptr;
5077+
5078+
auto GetTripleIt = [&, this](llvm::StringRef Triple) {
5079+
llvm::Triple TargetTriple{Triple};
5080+
auto TripleIt = llvm::find_if(SYCLTripleList, [&](auto &SYCLTriple) {
5081+
return SYCLTriple == TargetTriple;
5082+
});
5083+
return TripleIt != SYCLTripleList.end() ? &*TripleIt : nullptr;
5084+
};
5085+
5086+
if (A->getOption().matches(options::OPT_fsycl_targets_EQ)) {
5087+
// spir64 target is actually JIT compilation, so we defer selection of
5088+
// bfloat16 libraries to runtime. For AOT we need libraries.
5089+
needLibs = TC->getTriple().getSubArch() != llvm::Triple::NoSubArch;
5090+
TargetBE = GetTripleIt(A->getValue(0));
5091+
if (TargetBE)
5092+
TargetOpt = A->getValue(0);
5093+
else
5094+
continue;
5095+
} else if (A->getOption().matches(options::OPT_Xsycl_backend_EQ)) {
5096+
// Passing device args: -Xsycl-target-backend=<triple> <opt>
5097+
TargetBE = GetTripleIt(A->getValue(0));
5098+
if (TargetBE)
5099+
DeviceOpt = A->getValue(1);
5100+
else
5101+
continue;
5102+
} else if (A->getOption().matches(options::OPT_Xsycl_backend)) {
5103+
// Passing device args: -Xsycl-target-backend <opt>
5104+
TargetBE = &SYCLTripleList.front();
5105+
DeviceOpt = A->getValue(0);
5106+
} else if (A->getOption().matches(options::OPT_Xs_separate)) {
5107+
// Passing device args: -Xs <opt>
5108+
DeviceOpt = A->getValue(0);
5109+
} else {
5110+
continue;
5111+
};
5112+
}
5113+
useNative = false;
5114+
if (needLibs)
5115+
if (TC->getTriple().getSubArch() == llvm::Triple::SPIRSubArch_gen &&
5116+
TargetOpt && DeviceOpt) {
5117+
5118+
auto checkBF = [=](std::string &Dev) {
5119+
static const std::regex BFFs("pvc.*|ats.*");
5120+
return std::regex_match(Dev, BFFs);
5121+
};
5122+
5123+
needLibs = true;
5124+
std::string Params{DeviceOpt};
5125+
size_t DevicesPos = Params.find("-device ");
5126+
useNative = false;
5127+
if (DevicesPos != std::string::npos) {
5128+
useNative = true;
5129+
std::istringstream Devices(Params.substr(DevicesPos + 8));
5130+
for (std::string S; std::getline(Devices, S, ',');) {
5131+
useNative &= checkBF(S);
5132+
}
5133+
}
5134+
}
5135+
5136+
return needLibs;
5137+
}
5138+
50675139
bool addSYCLDeviceLibs(const ToolChain *TC, ActionList &DeviceLinkObjects,
50685140
bool isSpirvAOT, bool isMSVCEnv) {
50695141
struct DeviceLibOptInfo {
@@ -5139,6 +5211,10 @@ class OffloadingActionBuilder final {
51395211
{"libsycl-fallback-imf", "libimf-fp32"},
51405212
{"libsycl-fallback-imf-fp64", "libimf-fp64"},
51415213
{"libsycl-fallback-imf-bf16", "libimf-bf16"}};
5214+
const SYCLDeviceLibsList sycl_device_bfloat16_fallback_lib = {
5215+
{"libsycl-fallback-bfloat16", "libm-bfloat16"}};
5216+
const SYCLDeviceLibsList sycl_device_bfloat16_native_lib = {
5217+
{"libsycl-native-bfloat16", "libm-bfloat16"}};
51425218
// ITT annotation libraries are linked in separately whenever the device
51435219
// code instrumentation is enabled.
51445220
const SYCLDeviceLibsList sycl_device_annotation_libs = {
@@ -5188,6 +5264,17 @@ class OffloadingActionBuilder final {
51885264
addInputs(sycl_device_wrapper_libs);
51895265
if (isSpirvAOT || TC->getTriple().isNVPTX())
51905266
addInputs(sycl_device_fallback_libs);
5267+
5268+
bool nativeBfloatLibs;
5269+
bool needBfloatLibs = selectBfloatLibs(TC, nativeBfloatLibs);
5270+
if (needBfloatLibs) {
5271+
// Add native or fallback bfloat16 library.
5272+
if (nativeBfloatLibs)
5273+
addInputs(sycl_device_bfloat16_native_lib);
5274+
else
5275+
addInputs(sycl_device_bfloat16_fallback_lib);
5276+
}
5277+
51915278
if (Args.hasFlag(options::OPT_fsycl_instrument_device_code,
51925279
options::OPT_fno_sycl_instrument_device_code, true))
51935280
addInputs(sycl_device_annotation_libs);

0 commit comments

Comments
 (0)