豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 4fdadc8

Browse files
LyricZhao
and Chenggang Zhao authored
Add support for hc_bwd on SM90 (#170)
* Add support for sm90_hc_bwd * minor fix * Remove all unsafe refs * Remove all unsafe refs * same bar name and add __syncwarp() * additional transpose warp * additional transpose warp * Revert "additional transpose warp" This reverts commit 331df9f893998661002b7cdda968feae2efbdc51. * remove one barrier * improve bf16 to fp32 conversion * fix bug * remove one barrier and postpone wgmma waitgroup<0> * Loop Invariant Code Motion * increase block_M to 128 * Minor Fix * Minor Fix * Add one barrier for stress test correctness * Add comments * Some refactors with the new device APIs --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
1 parent 9f4799d commit 4fdadc8

File tree

6 files changed

+641
-3
lines changed

6 files changed

+641
-3
lines changed

csrc/apis/hyperconnection.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
66
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp"
77
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp"
8+
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_bwd_gemm.hpp" // oss-ignore-line
89
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_bwd_gemm.hpp" // oss-ignore-line
910
#endif
1011

@@ -97,7 +98,9 @@ static void tf32_hc_prenorm_bwd_gemm(const torch::Tensor& a,
9798

9899
// Dispatch into different implements
99100
const auto arch_major = device_runtime->get_arch_major();
100-
if (arch_major == 10) {
101+
if (arch_major == 9) {
102+
sm90_tf32_hc_prenorm_bwd_gemm(a, b, dd, ds, da, db, m, n, k, accumulate_on_da);
103+
} else if (arch_major == 10) {
101104
sm100_tf32_hc_prenorm_bwd_gemm(a, b, dd, ds, da, db, m, n, k, accumulate_on_da);
102105
} else {
103106
DG_HOST_UNREACHABLE("Unsupported architecture");

csrc/indexing/main.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
// Hyperconnection kernels
2525
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
2626
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
27-
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_bwd_gemm.cuh> // oss-ignore-line
27+
/* oss-ignore-begin */
28+
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_bwd_gemm.cuh>
29+
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_bwd_gemm.cuh>
30+
/* oss-ignore-end */
2831

2932
// Layout kernels
3033
#include <deep_gemm/impls/smxx_layout.cuh>
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#pragma once
2+
3+
#include <torch/python.h>
4+
5+
#include "../../jit/compiler.hpp"
6+
#include "../../jit/device_runtime.hpp"
7+
#include "../../jit/kernel_runtime.hpp"
8+
#include "../../utils/exception.hpp"
9+
#include "../../utils/format.hpp"
10+
#include "../../utils/math.hpp"
11+
#include "../heuristics/sm90.hpp"
12+
#include "runtime_utils.hpp"
13+
14+
namespace deep_gemm {
15+
16+
// JIT launch runtime for the SM90 hyperconnection prenorm backward GEMM.
// Generates the source that instantiates the device kernel from the
// compile-time tuning parameters in `Args`, and launches the compiled
// kernel with the runtime arguments.
// NOTE(review): the class name says "BF16" while the kernel symbol says
// "tf32" — presumably operand A is BF16 while the rest of the pipeline is
// TF32/FP32 (the host launcher asserts BF16/TF32 swizzle equality for A);
// confirm against the device kernel.
class SM90BF16HCPrenormBwdGemmRuntime final: public LaunchRuntime<SM90BF16HCPrenormBwdGemmRuntime> {
public:
    struct Args {
        // Problem sizes
        int m, n, k;
        // Tile sizes along M/N/K
        int block_m, block_n, block_k;
        // TMA swizzle modes for the N-, K- and M-major descriptors
        int swizzle_n_mode, swizzle_k_mode, swizzle_m_mode;
        // Non-zero to accumulate into existing `da` instead of overwriting
        int accumulate_on_da;
        // Number of software-pipeline stages
        int num_stages;
        // Thread counts for the TMA warpgroup and the dA/dB compute warpgroups
        int num_tma_threads, num_da_threads, num_db_threads;

        // Grid/block/shared-memory launch configuration
        LaunchArgs launch_args;

        // Raw output pointers (no TMA descriptor needed for these)
        float* ds;
        float* db;
        // TMA descriptors for the tiled operands
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_dd;
        CUtensorMap tensor_map_da;
    };

    // Emits the CUDA translation unit that instantiates the device kernel
    // template with all compile-time parameters baked in. The `{{`/`}}`
    // pairs are fmt escapes for literal braces.
    static std::string generate_impl(const Args& args) {
        return fmt::format(R"(
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_bwd_gemm.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_tf32_hc_prenorm_bwd_gemm_impl<
        {}, {},
        {}, {}, {},
        {}, {}, {},
        {},
        {},
        {}, {}, {}
    >);
}};
)",
        // Argument order must match the kernel template parameter order
        args.n, args.k,
        args.block_m, args.block_n, args.block_k,
        args.swizzle_n_mode, args.swizzle_k_mode, args.swizzle_m_mode,
        args.accumulate_on_da,
        args.num_stages,
        args.num_tma_threads, args.num_da_threads, args.num_db_threads);
    }

    // Launches the compiled kernel. Only runtime-varying values are passed;
    // everything else was baked in at JIT-compile time by `generate_impl`.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
            args.m, args.ds, args.db,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_dd,
            args.tensor_map_da));
    }
};
69+
70+
// Host-side launcher for the SM90 hyperconnection prenorm backward GEMM.
//
// Builds TMA descriptors for the tiled operands (`a` as BF16, `b`, `dd`,
// `da` as FP32), picks the deepest pipeline stage count that fits in SM90
// shared memory, then JIT-compiles and launches the kernel.
//
// Parameters:
//   a, b            - forward operands (a is BF16; see swizzle assertion below)
//   dd              - incoming gradient tile (presumably d(output); confirm with the forward pass)
//   ds, db          - gradient outputs written through raw float pointers
//   da              - gradient output written through TMA
//   m, n, k         - problem sizes
//   accumulate_on_da - if true, the kernel accumulates into `da` instead of overwriting
static void sm90_tf32_hc_prenorm_bwd_gemm(const torch::Tensor& a,
                                          const torch::Tensor& b,
                                          const torch::Tensor& dd,
                                          const torch::Tensor& ds,
                                          const torch::Tensor& da,
                                          const torch::Tensor& db,
                                          const int& m, const int& n, const int& k,
                                          const bool& accumulate_on_da) {
    constexpr int block_m = 128;
    const int block_n = align(n, 16);
    DG_HOST_ASSERT(n <= block_n);
    // Only support small N for now
    DG_HOST_ASSERT(n <= 32 and n % 8 == 0);
    constexpr int block_k = 64;

    // One warpgroup each for TMA, dA computation and dB computation
    constexpr int num_tma_threads = 128;
    constexpr int num_da_threads = 128;
    constexpr int num_db_threads = 128;

    // NOTES: block K must be large enough (>= 64) to ensure TF32 and BF16 swizzling are the same
    const auto& swizzle_n_mode = get_swizzle_mode(block_n, sizeof(float));
    const auto& swizzle_k_mode = get_swizzle_mode(block_k, sizeof(float));
    const auto& swizzle_m_mode = get_swizzle_mode(block_m, sizeof(float));
    DG_HOST_ASSERT(swizzle_k_mode == get_swizzle_mode(block_k, sizeof(nv_bfloat16))); // for tma_a (BF16)

    const auto tensor_map_a = make_tma_b_desc(cute::UMMA::Major::MN, a, k, m,
                                              block_k, block_m,
                                              static_cast<int>(a.stride(0)), 1,
                                              swizzle_k_mode);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, k, n,
                                              block_k, block_n,
                                              static_cast<int>(b.stride(0)), 1,
                                              swizzle_k_mode, 0, true);
    const auto tensor_map_dd = make_tma_a_desc(cute::UMMA::Major::MN, dd, n, m,
                                               block_n, block_m,
                                               static_cast<int>(dd.stride(0)), 1,
                                               swizzle_n_mode, 0, true);
    const auto tensor_map_da = make_tma_cd_desc(da, m, k, // (m, k) k inner major
                                                block_m, block_k,
                                                static_cast<int>(da.stride(0)), 1,
                                                swizzle_k_mode);

    // Calculate stages: start from the deepest pipeline and shrink until the
    // total shared-memory footprint fits. Per-stage sizes are stage-count
    // invariant, so compute them once outside the search loop.
    const int smem_dd_per_stage = block_m * block_n * sizeof(float);
    const int smem_a_per_stage = block_m * block_k * sizeof(nv_bfloat16);
    const int smem_ds_per_stage = block_m * sizeof(float);
    const int smem_b = block_n * block_k * sizeof(float);
    const int smem_da = block_m * block_k * sizeof(nv_bfloat16) * 2;
    int num_stages = 16, smem_size = 0;
    while (num_stages > 0) {
        // 4 barriers per stage plus 1 extra, 8 bytes each
        const int smem_barriers = (num_stages * 4 + 1) * 8;
        smem_size = (smem_dd_per_stage + smem_a_per_stage + smem_ds_per_stage) * num_stages +
                    smem_da + smem_b + smem_barriers;

        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, "
               "stages: %d, shared memory: %d, "
               "swizzle N: %d, swizzle K: %d, swizzle M: %d\n",
               m, n, k, block_m, block_n, block_k,
               num_stages, smem_size,
               swizzle_n_mode, swizzle_k_mode, swizzle_m_mode);
    }

    // Launch: grid is split along K; one block hosts all three warpgroups
    const SM90BF16HCPrenormBwdGemmRuntime::Args args = {
        .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .swizzle_n_mode = swizzle_n_mode, .swizzle_k_mode = swizzle_k_mode, .swizzle_m_mode = swizzle_m_mode,
        .accumulate_on_da = static_cast<int>(accumulate_on_da),
        .num_stages = num_stages,
        .num_tma_threads = num_tma_threads,
        .num_da_threads = num_da_threads,
        .num_db_threads = num_db_threads,
        .launch_args = LaunchArgs(ceil_div(k, block_k), num_tma_threads + num_da_threads + num_db_threads, smem_size, 1),
        .ds = ds.data_ptr<float>(),
        .db = db.data_ptr<float>(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_dd = tensor_map_dd,
        .tensor_map_da = tensor_map_da,
    };
    const auto code = SM90BF16HCPrenormBwdGemmRuntime::generate(args);
    const auto runtime = compiler->build("sm90_tf32_hc_prenorm_bwd_gemm", code);
    SM90BF16HCPrenormBwdGemmRuntime::launch(runtime, args);
}
163+
164+
} // namespace deep_gemm

deep_gemm/include/deep_gemm/common/math.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ __forceinline__ __device__ void swap(T& a, T& b) {
4545
b = temp;
4646
}
4747

48+
// Lane-wise fused multiply-add on packed float2 values: each component is
// a single round-to-nearest-even FMA, i.e. (a.x*b.x + c.x, a.y*b.y + c.y).
__device__ __forceinline__ float2 fma2(const float2& a, const float2& b, const float2& c) {
    float2 result;
    result.x = __fmaf_rn(a.x, b.x, c.x);
    result.y = __fmaf_rn(a.y, b.y, c.y);
    return result;
}
54+
4855
/// Casting
4956
template <typename old_t>
5057
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {

0 commit comments

Comments
 (0)