12 changes: 6 additions & 6 deletions csrc/apis/attention.hpp

@@ -263,13 +263,13 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
         kv_cache = torch::from_blob(
             fused_kv_cache.data_ptr(),
             {num_kv_blocks, block_kv, head_dim / 2},
-            {kv_cache_stride_bytes, head_dim / 2, 1},
+            {kv_cache_stride_bytes, fp4_with_sf_bytes, 1},
             torch::TensorOptions().dtype(kPackedFP4)
         );
         kv_cache_sf = torch::from_blob(
-            fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim / 2,
+            fused_kv_cache.data_ptr<uint8_t>() + head_dim / 2,
             {num_kv_blocks, block_kv},
-            {kv_cache_stride_bytes / static_cast<int>(sizeof(int)), 1},
+            {kv_cache_stride_bytes / static_cast<int>(sizeof(int)), fp4_with_sf_bytes / static_cast<int>(sizeof(int))},
             torch::TensorOptions().dtype(torch::kInt32)
         );
     } else {
@@ -295,13 +295,13 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, st
         kv_cache = torch::from_blob(
             fused_kv_cache.data_ptr(),
             {num_kv_blocks, block_kv, head_dim},
-            {kv_cache_stride_bytes, head_dim, 1},
+            {kv_cache_stride_bytes, head_dim_with_sf, 1},
             torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
         );
         kv_cache_sf = torch::from_blob(
-            fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
+            fused_kv_cache.data_ptr<uint8_t>() + head_dim,
             {num_kv_blocks, block_kv},
-            {kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
+            {kv_cache_stride_bytes / static_cast<int>(sizeof(float)), head_dim_with_sf / static_cast<int>(sizeof(float))},
             torch::TensorOptions().dtype(torch::kFloat32)
         );
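Both hunks fix the same addressing bug, once per branch: the fused cache stores each token row interleaved as [quantized payload | 4-byte scale factor], so the payload view must step a full row width (`fp4_with_sf_bytes` / `head_dim_with_sf`) between rows, and the scale-factor view must begin right after the first row's payload rather than at the planar offset `block_kv * head_dim`. Below is a minimal Python sketch of the same stride arithmetic using `torch.Tensor.as_strided`; the toy sizes and variable names are illustrative only, not the repository's API:

import torch

# Toy sizes mirroring the FP8 branch: each row holds `head_dim` payload bytes
# followed by a 4-byte scale factor (SF).
num_kv_blocks, block_kv, head_dim = 2, 4, 8
row_bytes = head_dim + 4                       # `head_dim_with_sf` in the diff
kv_cache_stride_bytes = block_kv * row_bytes   # bytes per KV block

buf = torch.zeros(num_kv_blocks * kv_cache_stride_bytes, dtype=torch.uint8)

# Payload view: starts at byte 0 and steps `row_bytes` per row, not one row
# every `head_dim` bytes as the old strides assumed.
kv_cache = buf.as_strided(
    (num_kv_blocks, block_kv, head_dim),
    (kv_cache_stride_bytes, row_bytes, 1),
)

# SF view: starts `head_dim` bytes in, right after the first row's payload,
# and steps by the same per-row stride. The C++ code reinterprets these
# 4 bytes as int32/float32; uint8 is kept here for simplicity.
kv_cache_sf = buf.as_strided(
    (num_kv_blocks, block_kv, 4),
    (kv_cache_stride_bytes, row_bytes, 1),
    storage_offset=head_dim,
)

With the old planar offset, the SF view would have started `block_kv * head_dim` bytes into the block, landing in the middle of a later row's payload instead of on the first row's scale factor.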
20 changes: 12 additions & 8 deletions tests/test_attention.py

@@ -235,21 +235,25 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
     x_cast_back = x_scaled.float() * sf

-    x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
-    x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(torch.uint8)
-    x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(torch.uint8)
-    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), x_cast_back.to(x.dtype)
+    # Interleaved layout: [FP8_data (head_dim) | SF (4 bytes)] per row
+    row_stride = head_dim + 4
+    x_fp8 = torch.empty((num_blocks, block_size, num_heads, row_stride), device=x.device, dtype=torch.uint8)
+    x_fp8[:, :, :, :head_dim] = x_scaled.view(num_blocks, block_size, 1, head_dim).view(torch.uint8)
+    x_fp8[:, :, :, head_dim:] = sf.view(num_blocks, block_size, 1, 1).view(torch.uint8)
+    return x_fp8, x_cast_back.to(x.dtype)

 def kv_cache_cast_to_fp4(x: torch.Tensor) -> torch.Tensor:
     num_blocks, block_size, num_heads, head_dim = x.shape
     assert num_heads == 1 and head_dim == 128
     x_scaled, sf = per_token_cast_to_fp4(x.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
     x_cast_back = cast_back_from_fp4(x_scaled, sf, gran_k=32, use_packed_ue8m0=True).view(num_blocks, block_size, 1, head_dim)

-    x_fp4 = torch.empty((num_blocks, block_size * (head_dim // 2 + 4)), device=x.device, dtype=torch.uint8)
-    x_fp4[ :, : block_size * head_dim // 2] = x_scaled.view(num_blocks, block_size * head_dim // 2).view(torch.uint8)
-    x_fp4[ :, block_size * head_dim // 2 :] = sf.view(num_blocks, block_size).view(torch.uint8)
-    return x_fp4.view(num_blocks, block_size, num_heads, head_dim // 2 + 4), x_cast_back.to(x.dtype)
+    # Interleaved layout: [FP4_data (head_dim // 2) | SF (4 bytes)] per row
+    row_stride = head_dim // 2 + 4
+    x_fp4 = torch.empty((num_blocks, block_size, num_heads, row_stride), device=x.device, dtype=torch.uint8)
+    x_fp4[:, :, :, :head_dim // 2] = x_scaled.view(num_blocks, block_size, 1, head_dim // 2).view(torch.uint8)
+    x_fp4[:, :, :, head_dim // 2:] = sf.view(num_blocks, block_size, 1, 1).view(torch.uint8)
+    return x_fp4, x_cast_back.to(x.dtype)

 def enumerate_paged_mqa_logits():
     arch_major = get_arch_major()
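As a sanity check on the interleaved layout the rewritten helpers produce, the two planes can be sliced back apart and reinterpreted by dtype. This is a hedged sketch rather than part of the test file; `split_fp8_kv_cache` is a hypothetical helper that assumes the `(num_blocks, block_size, num_heads, head_dim + 4)` uint8 layout built by `kv_cache_cast_to_fp8` above:

import torch

def split_fp8_kv_cache(x_fp8: torch.Tensor, head_dim: int):
    # x_fp8: (num_blocks, block_size, num_heads, head_dim + 4) uint8 rows,
    # laid out as [FP8 payload | 4-byte float32 SF].
    data = x_fp8[..., :head_dim].view(torch.float8_e4m3fn)
    # Make the 4-byte SF slice contiguous before reinterpreting each row's
    # bytes as a single float32.
    sf = x_fp8[..., head_dim:].contiguous().view(torch.float32)
    return data, sf

Under these assumptions, `data.float() * sf` should approximately reproduce the original values, matching the `x_cast_back` reference returned alongside the cache.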