diff --git a/CMakeLists.txt b/CMakeLists.txt index 79f1964d..04001d10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,8 +18,8 @@ find_package(CUDAToolkit REQUIRED) find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CUDA_STANDARD 20) include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 2aa27066..9792f49d 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "compatibility.hpp" @@ -17,20 +18,25 @@ class DGException final : public std::exception { message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; } + explicit DGException(const char *name, const char* file, const int line, const std::string& error, const std::source_location& caller) { + message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error + + "\ncaller: " + caller.file_name() + '(' + caller.line() + ')' + caller.function_name(); + } + const char *what() const noexcept override { return message.c_str(); } }; #ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond __VA_OPT__(,) __VA_ARGS__) #endif #ifndef DG_HOST_ASSERT -#define DG_HOST_ASSERT(cond) \ +#define DG_HOST_ASSERT(cond, ...) \ do { \ if (not (cond)) { \ - throw DGException("Assertion", __FILE__, __LINE__, #cond); \ + throw DGException("Assertion", __FILE__, __LINE__, #cond __VA_OPT__(,) __VA_ARGS__); \ } \ } while (0) #endif diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp index d67cfcfb..0be986f6 100644 --- a/csrc/utils/layout.hpp +++ b/csrc/utils/layout.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "math.hpp" #include "exception.hpp" @@ -10,22 +11,22 @@ namespace deep_gemm { // Major-ness stuffs -static void major_check(const torch::Tensor& t) { +static void major_check(const torch::Tensor& t, const std::source_location location = std::source_location::current()) { const auto dim = t.dim(); - DG_HOST_ASSERT(dim == 2 or dim == 3); + DG_HOST_ASSERT(dim == 2 or dim == 3, location); if (dim == 3) - DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1)); - DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1); + DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1), location); + DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1, location); } -static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) { - major_check(t); +static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t, const std::source_location location = std::source_location::current()) { + major_check(t, location); return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; } -static void check_major_type_cd(const torch::Tensor& t) { +static void check_major_type_cd(const torch::Tensor& t, const std::source_location location = std::source_location::current()) { // NOTES: the library only supports row-major output layouts - major_check(t); + major_check(t, location); DG_HOST_ASSERT(t.stride(-1) == 1); } @@ -35,15 +36,15 @@ static bool fp8_requires_k_major() { // Tensor utils template -static auto get_shape(const torch::Tensor& t) { - DG_HOST_ASSERT(t.dim() == N); +static auto get_shape(const torch::Tensor& t, const std::source_location location = std::source_location::current()) { + DG_HOST_ASSERT(t.dim() == N, location); return [&t] (std::index_sequence) { return std::make_tuple(static_cast(t.sizes()[Is])...); }(std::make_index_sequence()); } -static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { - auto [mn, k] = get_shape<2>(ab); +static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major, const std::source_location location = std::source_location::current()) { + auto [mn, k] = get_shape<2>(ab, location); if (ab.scalar_type() != torch::kFloat8_e4m3fn) { DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); @@ -51,8 +52,8 @@ static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute return std::make_tuple(mn, k); } -static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { - auto [num_groups, mn, k] = get_shape<3>(ab); +static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major, const std::source_location location = std::source_location::current()) { + auto [num_groups, mn, k] = get_shape<3>(ab, location); if (ab.scalar_type() != torch::kFloat8_e4m3fn) { DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); @@ -83,35 +84,36 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf, const std::optional& num_groups, const bool& tma_stride_check = false, const bool& sm90_sfb_check = false, - const std::optional& type_check = std::nullopt) { + const std::optional& type_check = std::nullopt, + const std::source_location location = std::source_location::current()) { // Type check if (type_check.has_value()) - DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); + DG_HOST_ASSERT(sf.scalar_type() == type_check.value(), location); // Always do shape checks const auto sf_dtype = sf.scalar_type(); - DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); - DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); + DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt, location); + DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2, location); if (num_groups.has_value()) - DG_HOST_ASSERT(sf.size(-3) == num_groups.value()); - DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn)); - DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4))); + DG_HOST_ASSERT(sf.size(-3) == num_groups.value(), location); + DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn), location); + DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)), location); // TMA stride checks: TMA aligned and MN-major if (tma_stride_check) { if (num_groups.has_value()) - DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); + DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1), location); // Check contiguity in the MN direction - DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1); - DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); + DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1, location); + DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()), location); } // SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions if (sm90_sfb_check) { if (num_groups.has_value()) - DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1)); + DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1), location); DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or - (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1)); + (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1), location); } return sf; }