豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 9f4799d

Browse files
authored
Loading with the Library Enumerate Kernels API (#181)
* Support Driver Library Enumerate Kernels * Minor fix * Many fixes * Lint * Minor fix
1 parent 88398ce commit 9f4799d

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

csrc/jit/handle.hpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute);
3939
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad);
4040
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload);
4141
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction);
42+
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryLoadFromFile);
43+
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryUnload);
44+
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction);
4245
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
4346
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);
4447

@@ -103,33 +106,57 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle&
103106
#else
104107

105108
// Use CUDA driver API
106-
using LibraryHandle = CUmodule;
107109
using KernelHandle = CUfunction;
108110
using LaunchConfigHandle = CUlaunchConfig;
109111
using LaunchAttrHandle = CUlaunchAttribute;
110112

113+
// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4
114+
#if CUDA_VERSION >= 12040
115+
#define DG_JIT_USE_LIBRARY_ENUM_KERNELS
116+
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount);
117+
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels);
118+
using LibraryHandle = CUlibrary;
119+
#else
120+
using LibraryHandle = CUmodule;
121+
#endif
122+
111123
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK
112124

113125
// Load a single kernel from a compiled cubin file via the CUDA driver API.
//
// On CUDA >= 12.4 (DG_JIT_USE_LIBRARY_ENUM_KERNELS) the cubin is loaded as a
// CUlibrary and the kernel is discovered by enumeration, so `func_name` is not
// consulted; the cubin is expected to contain exactly one kernel. On older
// drivers the cubin is loaded as a CUmodule and the kernel is resolved by
// `func_name`.
//
// Parameters:
//   cubin_path  - filesystem path of the cubin to load
//   func_name   - kernel symbol name (ignored on the enumeration path)
//   library_opt - if non-null, receives the loaded library/module handle so
//                 the caller can later release it with `unload_library`
// Returns the kernel handle ready for launching.
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
                                LibraryHandle *library_opt = nullptr) {
    LibraryHandle library;
    KernelHandle kernel;

#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
    // The kernel is found by enumeration, not by name; silence the
    // unused-parameter warning this branch would otherwise emit.
    static_cast<void>(func_name);
    DG_CUDA_DRIVER_CHECK(lazy_cuLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
    unsigned int num_kernels;
    DG_CUDA_DRIVER_CHECK(lazy_cuLibraryGetKernelCount(&num_kernels, library));
    // The JIT pipeline emits exactly one kernel per cubin; anything else is a bug.
    DG_HOST_ASSERT(num_kernels == 1);
    CUkernel cu_kernel;
    DG_CUDA_DRIVER_CHECK(lazy_cuLibraryEnumerateKernels(&cu_kernel, 1, library));
    DG_CUDA_DRIVER_CHECK(lazy_cuKernelGetFunction(&kernel, cu_kernel));
#else
    DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str()));
    DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str()));
#endif

    // Hand the library handle back so the caller owns the unload.
    if (library_opt != nullptr)
        *library_opt = library;
    return kernel;
}
124147

125148
static void unload_library(const LibraryHandle& library) {
149+
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
150+
const auto error = lazy_cuLibraryUnload(library);
151+
#else
126152
const auto error = lazy_cuModuleUnload(library);
153+
#endif
127154
DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED);
128155
}
129156

130157
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
131-
const cudaStream_t& stream, const int& smem_size,
132-
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
158+
const cudaStream_t& stream, const int& smem_size,
159+
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
133160
if (smem_size > 0)
134161
DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));
135162

csrc/jit/kernel_runtime.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class KernelRuntime final {
4646
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME"))
4747
start_time = std::chrono::high_resolution_clock::now();
4848

49+
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
50+
// Load from the library
51+
kernel = load_kernel(cubin_path, {}, &library);
52+
#else
4953
// Find the only symbol
5054
// TODO: use kernel enumeration for newer drivers
5155
const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"};
@@ -75,6 +79,7 @@ class KernelRuntime final {
7579

7680
// Load from the library
7781
kernel = load_kernel(cubin_path, symbol_names[0], &library);
82+
#endif
7883

7984
// Print load time
8085
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME")) {

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,16 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
7272
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
7373
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
7474
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
75+
#if CUDA_VERSION >= 12080
7576
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
7677
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
78+
#endif
7779
default: DG_HOST_UNREACHABLE("Unsupported dtype");
7880
}
7981
}
8082

8183
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
82-
#if CUDART_VERSION >= 12080
84+
#if CUDA_VERSION >= 12080
8385
if (base != 0) {
8486
DG_HOST_ASSERT(base == 32 and mode == 128);
8587
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;

0 commit comments

Comments (0)