豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
Comment on lines +21 to +22
Copy link
Copy Markdown
Author

@YouJiacheng YouJiacheng Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both `std::source_location` and `__VA_OPT__` require C++20.


include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
Expand Down
12 changes: 9 additions & 3 deletions csrc/utils/exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <exception>
#include <string>
#include <sstream>
#include <source_location>

#include "compatibility.hpp"

Expand All @@ -17,20 +18,25 @@ class DGException final : public std::exception {
message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error;
}

// Overload that additionally records the caller's source location, so the
// exception message shows both the assertion site and the API call site.
explicit DGException(const char *name, const char* file, const int line, const std::string& error, const std::source_location& caller) {
    // NOTE: `caller.line()` returns `std::uint_least32_t`; appending it directly
    // would resolve to `operator+(std::string, char)` via an implicit narrowing
    // conversion and emit a single garbage character instead of the decimal line
    // number — it must go through `std::to_string`.
    message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error
        + "\ncaller: " + caller.file_name() + '(' + std::to_string(caller.line()) + ')' + caller.function_name();
}

const char *what() const noexcept override {
return message.c_str();
}
};

#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond __VA_OPT__(,) __VA_ARGS__)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without `__VA_OPT__`, `DG_STATIC_ASSERT` fails to compile when `__VA_ARGS__` is empty (i.e. when no message is provided), because the expansion leaves a trailing comma: `static_assert(cond, )`.

#endif

#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
#define DG_HOST_ASSERT(cond, ...) \
do { \
if (not (cond)) { \
throw DGException("Assertion", __FILE__, __LINE__, #cond); \
throw DGException("Assertion", __FILE__, __LINE__, #cond __VA_OPT__(,) __VA_ARGS__); \
} \
} while (0)
#endif
Expand Down
54 changes: 28 additions & 26 deletions csrc/utils/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cute/arch/mma_sm100_umma.hpp>
#include <torch/python.h>
#include <source_location>

#include "math.hpp"
#include "exception.hpp"
Expand All @@ -10,22 +11,22 @@
namespace deep_gemm {

// Major-ness stuffs
static void major_check(const torch::Tensor& t) {
static void major_check(const torch::Tensor& t, const std::source_location location = std::source_location::current()) {
const auto dim = t.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
DG_HOST_ASSERT(dim == 2 or dim == 3, location);
if (dim == 3)
DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1), location);
DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1, location);
}

static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
major_check(t);
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t, const std::source_location location = std::source_location::current()) {
major_check(t, location);
return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
}

static void check_major_type_cd(const torch::Tensor& t) {
static void check_major_type_cd(const torch::Tensor& t, const std::source_location location = std::source_location::current()) {
// NOTES: the library only supports row-major output layouts
major_check(t);
major_check(t, location);
DG_HOST_ASSERT(t.stride(-1) == 1);
}

Expand All @@ -35,24 +36,24 @@ static bool fp8_requires_k_major() {

// Tensor utils
template <int N>
static auto get_shape(const torch::Tensor& t) {
DG_HOST_ASSERT(t.dim() == N);
static auto get_shape(const torch::Tensor& t, const std::source_location location = std::source_location::current()) {
DG_HOST_ASSERT(t.dim() == N, location);
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
}(std::make_index_sequence<N>());
}

static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
auto [mn, k] = get_shape<2>(ab);
static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major, const std::source_location location = std::source_location::current()) {
auto [mn, k] = get_shape<2>(ab, location);
if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
}
return std::make_tuple(mn, k);
}

static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
auto [num_groups, mn, k] = get_shape<3>(ab);
static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major, const std::source_location location = std::source_location::current()) {
auto [num_groups, mn, k] = get_shape<3>(ab, location);
if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
Expand Down Expand Up @@ -83,35 +84,36 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf,
const std::optional<int>& num_groups,
const bool& tma_stride_check = false,
const bool& sm90_sfb_check = false,
const std::optional<torch::ScalarType>& type_check = std::nullopt) {
const std::optional<torch::ScalarType>& type_check = std::nullopt,
const std::source_location location = std::source_location::current()) {
// Type check
if (type_check.has_value())
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
DG_HOST_ASSERT(sf.scalar_type() == type_check.value(), location);

// Always do shape checks
const auto sf_dtype = sf.scalar_type();
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt, location);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2, location);
if (num_groups.has_value())
DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));
DG_HOST_ASSERT(sf.size(-3) == num_groups.value(), location);
DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn), location);
DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)), location);

// TMA stride checks: TMA aligned and MN-major
if (tma_stride_check) {
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1), location);
// Check contiguity in the MN direction
DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1);
DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1, location);
DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()), location);
}

// SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions
if (sm90_sfb_check) {
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1));
DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1), location);
DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or
(sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1));
(sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1), location);
}
return sf;
}
Expand Down