豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 0e49c33

Browse files
committed
Refactor compiler version checks and arch flags
1 parent 3a93f4e commit 0e49c33

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

csrc/jit/compiler.hpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ class NVCCCompiler final: public Compiler {
140140
DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
141141
std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
142142
DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
143-
if (major < 12 or (major == 12 and minor < 9))
144-
printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance");
143+
if (major == 12 and minor < 9)
144+
printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n");
145145
return {major, minor};
146146
}
147147

@@ -155,14 +155,12 @@ class NVCCCompiler final: public Compiler {
155155
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
156156

157157
// Override the compiler flags
158-
std::string selected_arch = device_runtime->get_arch();
159-
// Compatibility: NVCC < 12.9 may not recognize sm_100f; fallback to sm_100a
160-
if (selected_arch == "100f" && (nvcc_major < 12 || (nvcc_major == 12 && nvcc_minor < 9)))
161-
selected_arch = "100a";
158+
// Only NVCC >= 12.9 supports arch-specific family suffix
159+
const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
162160
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
163161
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
164162
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
165-
flags, library_include_path.c_str(), selected_arch);
163+
flags, library_include_path.c_str(), arch);
166164
}
167165

168166
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
@@ -193,6 +191,7 @@ class NVRTCCompiler final: public Compiler {
193191
int major, minor;
194192
DG_NVRTC_CHECK(nvrtcVersion(&major, &minor));
195193
signature = fmt::format("NVRTC{}.{}", major, minor);
194+
DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3");
196195

197196
// Build include directories list
198197
std::string include_dirs;
@@ -202,19 +201,17 @@ class NVRTCCompiler final: public Compiler {
202201
// Add PCH support for version 12.8 and above
203202
// NOTES: PCH is vital for compilation speed
204203
std::string pch_flags;
205-
if (major > 12 or (major == 12 and minor >= 8)) {
204+
if (major > 12 or minor >= 8) {
206205
pch_flags = "--pch ";
207206
if (get_env<int>("DG_JIT_DEBUG", 0))
208207
pch_flags += "--pch-verbose=true ";
209208
}
210209

211210
// Override the compiler flags
212-
std::string selected_arch = device_runtime->get_arch();
213-
// Compatibility: NVRTC < 12.9 may not recognize sm_100f; fallback to sm_100a
214-
if (selected_arch == "100f" && (major < 12 || (major == 12 && minor < 9)))
215-
selected_arch = "100a";
211+
// Only NVRTC >= 12.9 supports arch-specific family suffix
212+
const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
216213
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
217-
flags, include_dirs, selected_arch, pch_flags);
214+
flags, include_dirs, arch, pch_flags);
218215
}
219216

220217
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {

csrc/jit/device_runtime.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ class DeviceRuntime {
2525
return {prop->major, prop->minor};
2626
}
2727

28-
std::string get_arch() {
28+
std::string get_arch(const bool& number_only = false,
29+
const bool& support_arch_family = false) {
2930
const auto& [major, minor] = get_arch_pair();
30-
if (major == 10 and minor != 1)
31-
return "100f";
32-
return std::to_string(major * 10 + minor) + "a";
31+
if (major == 10 and minor != 1) {
32+
if (number_only)
33+
return "100";
34+
return support_arch_family ? "100f" : "100a";
35+
}
36+
return std::to_string(major * 10 + minor) + (number_only ? "" : "a");
3337
}
3438

3539
int get_arch_major() {

0 commit comments

Comments
 (0)