diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp
index 38ff225d..c15d19ca 100644
--- a/csrc/apis/attention.hpp
+++ b/csrc/apis/attention.hpp
@@ -263,13 +263,13 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple
-            fused_kv_cache.data_ptr() + block_kv * head_dim / 2,
+            fused_kv_cache.data_ptr() + head_dim / 2,
             {num_kv_blocks, block_kv},
-            {kv_cache_stride_bytes / static_cast(sizeof(int)), 1},
+            {kv_cache_stride_bytes / static_cast(sizeof(int)), fp4_with_sf_bytes / static_cast(sizeof(int))},
             torch::TensorOptions().dtype(torch::kInt32)
         );
     } else {
@@ -295,13 +295,13 @@ static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple
-            fused_kv_cache.data_ptr() + block_kv * head_dim,
+            fused_kv_cache.data_ptr() + head_dim,
             {num_kv_blocks, block_kv},
-            {kv_cache_stride_bytes / static_cast(sizeof(float)), 1},
+            {kv_cache_stride_bytes / static_cast(sizeof(float)), head_dim_with_sf / static_cast(sizeof(float))},
             torch::TensorOptions().dtype(torch::kFloat32)
         );
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 6df1fc4f..d7885a7c 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -235,10 +235,12 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
     x_cast_back = x_scaled.float() * sf
 
-    x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
-    x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(torch.uint8)
-    x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(torch.uint8)
-    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), x_cast_back.to(x.dtype)
+    # Interleaved layout: [FP8_data (head_dim) | SF (4 bytes)] per row
+    row_stride = head_dim + 4
+    x_fp8 = torch.empty((num_blocks, block_size, num_heads, row_stride), device=x.device, dtype=torch.uint8)
+    x_fp8[:, :, :, :head_dim] = x_scaled.view(torch.uint8)
+    x_fp8[:, :, :, head_dim:] = sf.view(num_blocks, block_size, 1, 1).expand(-1, -1, 1, 4).view(torch.uint8)
+    return x_fp8, x_cast_back.to(x.dtype)
 
 def kv_cache_cast_to_fp4(x: torch.Tensor) -> torch.Tensor:
     num_blocks, block_size, num_heads, head_dim = x.shape
@@ -246,10 +248,12 @@ def kv_cache_cast_to_fp4(x: torch.Tensor) -> torch.Tensor:
     x_scaled, sf = per_token_cast_to_fp4(x.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
     x_cast_back = cast_back_from_fp4(x_scaled, sf, gran_k=32, use_packed_ue8m0=True).view(num_blocks, block_size, 1, head_dim)
 
-    x_fp4 = torch.empty((num_blocks, block_size * (head_dim // 2 + 4)), device=x.device, dtype=torch.uint8)
-    x_fp4[ :, : block_size * head_dim // 2] = x_scaled.view(num_blocks, block_size * head_dim // 2).view(torch.uint8)
-    x_fp4[ :, block_size * head_dim // 2 :] = sf.view(num_blocks, block_size).view(torch.uint8)
-    return x_fp4.view(num_blocks, block_size, num_heads, head_dim // 2 + 4), x_cast_back.to(x.dtype)
+    # Interleaved layout: [FP4_data (head_dim//2) | SF (4 bytes)] per row
+    row_stride = head_dim // 2 + 4
+    x_fp4 = torch.empty((num_blocks, block_size, num_heads, row_stride), device=x.device, dtype=torch.uint8)
+    x_fp4[:, :, :, :head_dim // 2] = x_scaled.view(torch.uint8)
+    x_fp4[:, :, :, head_dim // 2:] = sf.view(num_blocks, block_size, 1, 1).expand(-1, -1, 1, 4).view(torch.uint8)
+    return x_fp4, x_cast_back.to(x.dtype)
 
 def enumerate_paged_mqa_logits():
     arch_major = get_arch_major()