@@ -39,6 +39,9 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute);
3939DECL_LAZY_CUDA_DRIVER_FUNCTION (cuModuleLoad);
4040DECL_LAZY_CUDA_DRIVER_FUNCTION (cuModuleUnload);
4141DECL_LAZY_CUDA_DRIVER_FUNCTION (cuModuleGetFunction);
42+ DECL_LAZY_CUDA_DRIVER_FUNCTION (cuLibraryLoadFromFile);
43+ DECL_LAZY_CUDA_DRIVER_FUNCTION (cuLibraryUnload);
44+ DECL_LAZY_CUDA_DRIVER_FUNCTION (cuKernelGetFunction);
4245DECL_LAZY_CUDA_DRIVER_FUNCTION (cuLaunchKernelEx);
4346DECL_LAZY_CUDA_DRIVER_FUNCTION (cuTensorMapEncodeTiled);
4447
@@ -103,33 +106,57 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle&
103106#else
104107
105108// Use CUDA driver API
106- using LibraryHandle = CUmodule;
107109using KernelHandle = CUfunction;
108110using LaunchConfigHandle = CUlaunchConfig;
109111using LaunchAttrHandle = CUlaunchAttribute;
110112
113+ // `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4
114+ #if CUDA_VERSION >= 12040
115+ #define DG_JIT_USE_LIBRARY_ENUM_KERNELS
116+ DECL_LAZY_CUDA_DRIVER_FUNCTION (cuLibraryGetKernelCount);
117+ DECL_LAZY_CUDA_DRIVER_FUNCTION (cuLibraryEnumerateKernels);
118+ using LibraryHandle = CUlibrary;
119+ #else
120+ using LibraryHandle = CUmodule;
121+ #endif
122+
111123#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK
112124
113125static KernelHandle load_kernel (const std::filesystem::path& cubin_path, const std::string& func_name,
114126 LibraryHandle *library_opt = nullptr ) {
115127 LibraryHandle library;
116128 KernelHandle kernel;
129+
130+ #ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
131+ DG_CUDA_DRIVER_CHECK (lazy_cuLibraryLoadFromFile (&library, cubin_path.c_str (), nullptr , nullptr , 0 , nullptr , nullptr , 0 ));
132+ unsigned int num_kernels;
133+ DG_CUDA_DRIVER_CHECK (lazy_cuLibraryGetKernelCount (&num_kernels, library));
134+ DG_HOST_ASSERT (num_kernels == 1 );
135+ CUkernel cu_kernel;
136+ DG_CUDA_DRIVER_CHECK (lazy_cuLibraryEnumerateKernels (&cu_kernel, 1 , library));
137+ DG_CUDA_DRIVER_CHECK (lazy_cuKernelGetFunction (&kernel, cu_kernel));
138+ #else
117139 DG_CUDA_DRIVER_CHECK (lazy_cuModuleLoad (&library, cubin_path.c_str ()));
118140 DG_CUDA_DRIVER_CHECK (lazy_cuModuleGetFunction (&kernel, library, func_name.c_str ()));
141+ #endif
119142
120143 if (library_opt != nullptr )
121144 *library_opt = library;
122145 return kernel;
123146}
124147
125148static void unload_library (const LibraryHandle& library) {
149+ #ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
150+ const auto error = lazy_cuLibraryUnload (library);
151+ #else
126152 const auto error = lazy_cuModuleUnload (library);
153+ #endif
127154 DG_HOST_ASSERT (error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED);
128155}
129156
130157static LaunchConfigHandle construct_launch_config (const KernelHandle& kernel,
131- const cudaStream_t& stream, const int & smem_size,
132- const dim3& grid_dim, const dim3& block_dim, const int & cluster_dim) {
158+ const cudaStream_t& stream, const int & smem_size,
159+ const dim3& grid_dim, const dim3& block_dim, const int & cluster_dim) {
133160 if (smem_size > 0 )
134161 DG_CUDA_DRIVER_CHECK (lazy_cuFuncSetAttribute (kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));
135162
0 commit comments