|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <torch/python.h> |
| 4 | + |
| 5 | +#include "../../jit/compiler.hpp" |
| 6 | +#include "../../jit/device_runtime.hpp" |
| 7 | +#include "../../jit/kernel_runtime.hpp" |
| 8 | +#include "../../utils/exception.hpp" |
| 9 | +#include "../../utils/format.hpp" |
| 10 | +#include "../../utils/math.hpp" |
| 11 | +#include "../heuristics/sm90.hpp" |
| 12 | +#include "runtime_utils.hpp" |
| 13 | + |
| 14 | +namespace deep_gemm { |
| 15 | + |
| 16 | +class SM90BF16HCPrenormBwdGemmRuntime final: public LaunchRuntime<SM90BF16HCPrenormBwdGemmRuntime> { |
| 17 | +public: |
| 18 | + struct Args { |
| 19 | + int m, n, k; |
| 20 | + int block_m, block_n, block_k; |
| 21 | + int swizzle_n_mode, swizzle_k_mode, swizzle_m_mode; |
| 22 | + int accumulate_on_da; |
| 23 | + int num_stages; |
| 24 | + int num_tma_threads, num_da_threads, num_db_threads; |
| 25 | + |
| 26 | + LaunchArgs launch_args; |
| 27 | + |
| 28 | + float* ds; |
| 29 | + float* db; |
| 30 | + CUtensorMap tensor_map_a; |
| 31 | + CUtensorMap tensor_map_b; |
| 32 | + CUtensorMap tensor_map_dd; |
| 33 | + CUtensorMap tensor_map_da; |
| 34 | + }; |
| 35 | + |
| 36 | + static std::string generate_impl(const Args& args) { |
| 37 | + return fmt::format(R"( |
| 38 | +#include <deep_gemm/impls/sm90_tf32_hc_prenorm_bwd_gemm.cuh> |
| 39 | +
|
| 40 | +using namespace deep_gemm; |
| 41 | +
|
| 42 | +static void __instantiate_kernel() {{ |
| 43 | + auto ptr = reinterpret_cast<void*>(&sm90_tf32_hc_prenorm_bwd_gemm_impl< |
| 44 | + {}, {}, |
| 45 | + {}, {}, {}, |
| 46 | + {}, {}, {}, |
| 47 | + {}, |
| 48 | + {}, |
| 49 | + {}, {}, {} |
| 50 | + >); |
| 51 | +}}; |
| 52 | +)", |
| 53 | + args.n, args.k, |
| 54 | + args.block_m, args.block_n, args.block_k, |
| 55 | + args.swizzle_n_mode, args.swizzle_k_mode, args.swizzle_m_mode, |
| 56 | + args.accumulate_on_da, |
| 57 | + args.num_stages, |
| 58 | + args.num_tma_threads, args.num_da_threads, args.num_db_threads); |
| 59 | + } |
| 60 | + |
| 61 | + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { |
| 62 | + // TODO: optimize `args` copy |
| 63 | + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, |
| 64 | + args.m, args.ds, args.db, |
| 65 | + args.tensor_map_a, args.tensor_map_b, args.tensor_map_dd, |
| 66 | + args.tensor_map_da)); |
| 67 | + } |
| 68 | +}; |
| 69 | + |
| 70 | +static void sm90_tf32_hc_prenorm_bwd_gemm(const torch::Tensor& a, |
| 71 | + const torch::Tensor& b, |
| 72 | + const torch::Tensor& dd, |
| 73 | + const torch::Tensor& ds, |
| 74 | + const torch::Tensor& da, |
| 75 | + const torch::Tensor& db, |
| 76 | + const int& m, const int& n, const int& k, |
| 77 | + const bool& accumulate_on_da) { |
| 78 | + constexpr int block_m = 128; |
| 79 | + const int block_n = align(n, 16); |
| 80 | + DG_HOST_ASSERT(n <= block_n); |
| 81 | + // Only support small N for now |
| 82 | + DG_HOST_ASSERT(n <= 32 and n % 8 == 0); |
| 83 | + constexpr int block_k = 64; |
| 84 | + |
| 85 | + constexpr int num_tma_threads = 128; |
| 86 | + constexpr int num_da_threads = 128; |
| 87 | + constexpr int num_db_threads = 128; |
| 88 | + |
| 89 | + // NOTES: block K must be large enough (>= 64) to ensure TF32 and BF16 swizzling are the same |
| 90 | + const auto& swizzle_n_mode = get_swizzle_mode(block_n, sizeof(float)); |
| 91 | + const auto& swizzle_k_mode = get_swizzle_mode(block_k, sizeof(float)); |
| 92 | + const auto& swizzle_m_mode = get_swizzle_mode(block_m, sizeof(float)); |
| 93 | + DG_HOST_ASSERT(swizzle_k_mode == get_swizzle_mode(block_k, sizeof(nv_bfloat16))); // for tma_a (BF16) |
| 94 | + |
| 95 | + const auto tensor_map_a = make_tma_b_desc(cute::UMMA::Major::MN, a, k, m, |
| 96 | + block_k, block_m, |
| 97 | + static_cast<int>(a.stride(0)), 1, |
| 98 | + swizzle_k_mode); |
| 99 | + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, k, n, |
| 100 | + block_k, block_n, |
| 101 | + static_cast<int>(b.stride(0)), 1, |
| 102 | + swizzle_k_mode, 0, true); |
| 103 | + const auto tensor_map_dd = make_tma_a_desc(cute::UMMA::Major::MN, dd, n, m, |
| 104 | + block_n, block_m, |
| 105 | + static_cast<int>(dd.stride(0)), 1, |
| 106 | + swizzle_n_mode, 0, true); |
| 107 | + const auto tensor_map_da = make_tma_cd_desc(da, m, k, // (m, k) k inner major |
| 108 | + block_m, block_k, |
| 109 | + static_cast<int>(da.stride(0)), 1, |
| 110 | + swizzle_k_mode); |
| 111 | + |
| 112 | + // Calculate stages |
| 113 | + int num_stages = 16, smem_size = 0; |
| 114 | + while (num_stages > 0) { |
| 115 | + const int smem_dd_per_stage = block_m * block_n * sizeof(float); |
| 116 | + const int smem_a_per_stage = block_m * block_k * sizeof(nv_bfloat16); |
| 117 | + const int smem_b = block_n * block_k * sizeof(float); |
| 118 | + const int smem_ds_per_stage = block_m * sizeof(float); |
| 119 | + const int smem_da = block_m * block_k * sizeof(nv_bfloat16) * 2; |
| 120 | + const int smem_barriers = (num_stages * 4 + 1) * 8; |
| 121 | + smem_size = (smem_dd_per_stage + smem_a_per_stage + smem_ds_per_stage) * num_stages + |
| 122 | + smem_da + smem_b + smem_barriers; |
| 123 | + |
| 124 | + if (smem_size <= SM90ArchSpec::smem_capacity) |
| 125 | + break; |
| 126 | + -- num_stages; |
| 127 | + } |
| 128 | + DG_HOST_ASSERT(num_stages > 0); |
| 129 | + |
| 130 | + // Print configs |
| 131 | + if (get_env("DG_JIT_DEBUG", 0)) { |
| 132 | + printf("M: %d, N: %d, K: %d -> " |
| 133 | + "block M: %d, block N: %d, block K: %d, " |
| 134 | + "stages: %d, shared memory: %d, " |
| 135 | + "swizzle N: %d, swizzle K: %d\n, swizzle M: %d\n", |
| 136 | + m, n, k, block_m, block_n, block_k, |
| 137 | + num_stages, smem_size, |
| 138 | + swizzle_n_mode, swizzle_k_mode, swizzle_m_mode); |
| 139 | + } |
| 140 | + |
| 141 | + // Launch |
| 142 | + const SM90BF16HCPrenormBwdGemmRuntime::Args& args = { |
| 143 | + .m = m, .n = n, .k = k, |
| 144 | + .block_m = block_m, .block_n = block_n, .block_k = block_k, |
| 145 | + .swizzle_n_mode = swizzle_n_mode, .swizzle_k_mode = swizzle_k_mode, .swizzle_m_mode = swizzle_m_mode, |
| 146 | + .accumulate_on_da = static_cast<int>(accumulate_on_da), |
| 147 | + .num_stages = num_stages, |
| 148 | + .num_tma_threads = num_tma_threads, |
| 149 | + .num_da_threads = num_da_threads, |
| 150 | + .num_db_threads = num_db_threads, |
| 151 | + .launch_args = LaunchArgs(ceil_div(k, block_k), num_tma_threads + num_da_threads + num_db_threads, smem_size, 1), |
| 152 | + .ds = ds.data_ptr<float>(), |
| 153 | + .db = db.data_ptr<float>(), |
| 154 | + .tensor_map_a = tensor_map_a, |
| 155 | + .tensor_map_b = tensor_map_b, |
| 156 | + .tensor_map_dd = tensor_map_dd, |
| 157 | + .tensor_map_da = tensor_map_da, |
| 158 | + }; |
| 159 | + const auto code = SM90BF16HCPrenormBwdGemmRuntime::generate(args); |
| 160 | + const auto runtime = compiler->build("sm90_tf32_hc_prenorm_bwd_gemm", code); |
| 161 | + SM90BF16HCPrenormBwdGemmRuntime::launch(runtime, args); |
| 162 | +} |
| 163 | + |
| 164 | +} // namespace deep_gemm |
0 commit comments