@@ -140,8 +140,8 @@ class NVCCCompiler final: public Compiler {
140140 DG_HOST_ASSERT (std::regex_search (output, match, std::regex (R"( release (\d+\.\d+))" )));
141141 std::sscanf (match[1 ].str ().c_str (), " %d.%d" , &major, &minor);
142142 DG_HOST_ASSERT ((major > 12 or (major == 12 and minor >= 3 )) and " NVCC version should be >= 12.3" );
143- if (major < 12 or (major == 12 and minor < 9 ) )
144- printf (" Warning: please use at least NVCC 12.9 for the best DeepGEMM performance" );
143+ if (major == 12 and minor < 9 )
144+ printf (" Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n " );
145145 return {major, minor};
146146 }
147147
@@ -155,14 +155,12 @@ class NVCCCompiler final: public Compiler {
155155 signature = fmt::format (" NVCC{}.{}" , nvcc_major, nvcc_minor);
156156
157157 // The override the compiler flags
158- std::string selected_arch = device_runtime->get_arch ();
159- // Compatibility: NVCC < 12.9 may not recognize sm_100f; fallback to sm_100a
160- if (selected_arch == " 100f" && (nvcc_major < 12 || (nvcc_major == 12 && nvcc_minor < 9 )))
161- selected_arch = " 100a" ;
158+ // Only NVCC >= 12.9 supports arch-specific family suffix
159+ const auto & arch = device_runtime->get_arch (false , nvcc_major > 12 or nvcc_minor >= 9 );
162160 flags = fmt::format (" {} -I{} --gpu-architecture=sm_{} "
163161 " --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
164162 " -cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda" ,
165- flags, library_include_path.c_str (), selected_arch );
163+ flags, library_include_path.c_str (), arch );
166164 }
167165
168166 void compile (const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
@@ -193,6 +191,7 @@ class NVRTCCompiler final: public Compiler {
193191 int major, minor;
194192 DG_NVRTC_CHECK (nvrtcVersion (&major, &minor));
195193 signature = fmt::format (" NVRTC{}.{}" , major, minor);
194+ DG_HOST_ASSERT ((major > 12 or (major == 12 and minor >= 3 )) and " NVRTC version should be >= 12.3" );
196195
197196 // Build include directories list
198197 std::string include_dirs;
@@ -202,19 +201,17 @@ class NVRTCCompiler final: public Compiler {
202201 // Add PCH support for version 12.8 and above
203202 // NOTES: PCH is vital for compilation speed
204203 std::string pch_flags;
205- if (major > 12 or (major == 12 and minor >= 8 ) ) {
204+ if (major > 12 or minor >= 8 ) {
206205 pch_flags = " --pch " ;
207206 if (get_env<int >(" DG_JIT_DEBUG" , 0 ))
208207 pch_flags += " --pch-verbose=true " ;
209208 }
210209
211210 // Override the compiler flags
212- std::string selected_arch = device_runtime->get_arch ();
213- // Compatibility: NVRTC < 12.9 may not recognize sm_100f; fallback to sm_100a
214- if (selected_arch == " 100f" && (major < 12 || (major == 12 && minor < 9 )))
215- selected_arch = " 100a" ;
211+ // Only NVRTC >= 12.9 supports arch-specific family suffix
212+ const auto & arch = device_runtime->get_arch (false , major > 12 or minor >= 9 );
216213 flags = fmt::format (" {} {}--gpu-architecture=sm_{} -default-device {}" ,
217- flags, include_dirs, selected_arch , pch_flags);
214+ flags, include_dirs, arch , pch_flags);
218215 }
219216
220217 void compile (const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
0 commit comments