豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit 9b6dccd

Browse files
Misc optimizations (#240)
* Remove unnecessary cute::min
* Use add.rn.f32.bf16 for mixed-precision addition
* Code tidy-up
* Use __fdividef for kFastMath
* Revert buggy optimization
1 parent 2b71c00 commit 9b6dccd

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -589,11 +589,6 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
589589
// Wait for token TMA store to complete
590590
cute::tma_store_arrive();
591591
ptx::tma_store_wait<0>();
592-
}
593-
__syncwarp();
594-
595-
// Notify finishing
596-
if (cute::elect_one_sync()) {
597592
ptx::red_add_rel(
598593
workspace.get_l1_arrival_count_ptr(current_expert_idx, token_idx_in_expert / BLOCK_M), 1);
599594
}
@@ -1025,7 +1020,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
10251020
const auto denom = __fadd2_rn(make_float2(1.0f, 1.0f), neg_gate_exp);
10261021
float2 silu_gate;
10271022
if constexpr (kFastMath) {
1028-
silu_gate = make_float2(__fdiv_rn(gate.x, denom.x), __fdiv_rn(gate.y, denom.y));
1023+
silu_gate = make_float2(__fdividef(gate.x, denom.x), __fdividef(gate.y, denom.y));
10291024
} else {
10301025
silu_gate = make_float2(gate.x / denom.x, gate.y / denom.y);
10311026
}
@@ -1349,7 +1344,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
13491344
const auto bf16_values = reinterpret_cast<const nv_bfloat162*>(&uint4_values);
13501345
#pragma unroll
13511346
for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
1352-
reduced[j * kNumElemsPerUint4 + l] = __fadd2_rn(reduced[j * kNumElemsPerUint4 + l], __bfloat1622float2(bf16_values[l]));
1347+
ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]);
13531348
}
13541349
combine_phase ^= load_stage_idx;
13551350
load_stage_idx ^= 1;

deep_gemm/include/deep_gemm/ptx/utils.cuh

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <cuda/std/cstdint>
4+
#include <cuda_bf16.h>
45

56
#include <deep_gemm/common/exception.cuh>
67

@@ -14,14 +15,9 @@ CUTLASS_DEVICE uint32_t get_sm_idx() {
1415

1516
CUTLASS_DEVICE uint32_t get_lane_idx() {
1617
uint32_t lane_id;
17-
asm ("mov.u32 %0, %laneid;" : "=r"(lane_id));
18+
asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
1819
return lane_id;
1920
}
20-
__forceinline__ __device__ float warp_reduce_amax(const float& value, const uint32_t& mask) {
21-
float result;
22-
asm volatile("redux.sync.max.abs.NaN.f32 %0, %1, %d;\n" : "=f"(result) : "f"(value), "r"(mask));
23-
return result;
24-
}
2521

2622
template <typename dtype_t>
2723
CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) {
@@ -35,4 +31,15 @@ CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) {
3531
return recv_dtype;
3632
}
3733

34+
CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) {
35+
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
36+
// Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100
37+
asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast<uint16_t*>(&b.x)));
38+
asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast<uint16_t*>(&b.y)));
39+
#else
40+
const auto [x, y] = __bfloat1622float2(b);
41+
a.x += x, a.y += y;
42+
#endif
43+
}
44+
3845
} // namespace deep_gemm::ptx

0 commit comments

Comments (0)