@@ -403,69 +403,4 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
403403 }
404404}
405405
406- static void register_apis (pybind11::module_& m) {
407- // FP8 GEMMs
408- m.def (" fp8_gemm_nt" , &fp8_gemm_nt,
409- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
410- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
411- py::arg (" compiled_dims" ) = " nk" ,
412- py::arg (" disable_ue8m0_cast" ) = false );
413- m.def (" fp8_gemm_nn" , &fp8_gemm_nn,
414- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
415- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
416- py::arg (" compiled_dims" ) = " nk" ,
417- py::arg (" disable_ue8m0_cast" ) = false );
418- m.def (" fp8_gemm_tn" , &fp8_gemm_tn,
419- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
420- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
421- py::arg (" compiled_dims" ) = " mn" ,
422- py::arg (" disable_ue8m0_cast" ) = false );
423- m.def (" fp8_gemm_tt" , &fp8_gemm_tt,
424- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
425- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
426- py::arg (" compiled_dims" ) = " mn" ,
427- py::arg (" disable_ue8m0_cast" ) = false );
428- m.def (" m_grouped_fp8_gemm_nt_contiguous" , &m_grouped_fp8_gemm_nt_contiguous,
429- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
430- py::arg (" recipe" ) = std::nullopt , py::arg (" compiled_dims" ) = " nk" ,
431- py::arg (" disable_ue8m0_cast" ) = false );
432- m.def (" m_grouped_fp8_gemm_nn_contiguous" , &m_grouped_fp8_gemm_nn_contiguous,
433- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
434- py::arg (" recipe" ) = std::nullopt , py::arg (" compiled_dims" ) = " nk" ,
435- py::arg (" disable_ue8m0_cast" ) = false );
436- m.def (" m_grouped_fp8_gemm_nt_masked" , &m_grouped_fp8_gemm_nt_masked,
437- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" masked_m" ),
438- py::arg (" expected_m" ), py::arg (" recipe" ) = std::nullopt ,
439- py::arg (" compiled_dims" ) = " nk" , py::arg (" disable_ue8m0_cast" ) = false );
440- m.def (" k_grouped_fp8_gemm_tn_contiguous" , &k_grouped_fp8_gemm_tn_contiguous,
441- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" ks" ),
442- py::arg (" ks_tensor" ), py::arg (" c" ) = std::nullopt ,
443- py::arg (" recipe" ) = std::make_tuple (1 , 1 , 128 ),
444- py::arg (" compiled_dims" ) = " mn" );
445-
446- // BF16 GEMMs
447- m.def (" bf16_gemm_nt" , &bf16_gemm_nt,
448- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
449- py::arg (" c" ) = std::nullopt ,
450- py::arg (" compiled_dims" ) = " nk" );
451- m.def (" bf16_gemm_nn" , &bf16_gemm_nn,
452- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
453- py::arg (" c" ) = std::nullopt ,
454- py::arg (" compiled_dims" ) = " nk" );
455- m.def (" bf16_gemm_tn" , &bf16_gemm_tn,
456- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
457- py::arg (" c" ) = std::nullopt ,
458- py::arg (" compiled_dims" ) = " mn" );
459- m.def (" bf16_gemm_tt" , &bf16_gemm_tt,
460- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
461- py::arg (" c" ) = std::nullopt ,
462- py::arg (" compiled_dims" ) = " mn" );
463- m.def (" m_grouped_bf16_gemm_nt_contiguous" , &m_grouped_bf16_gemm_nt_contiguous,
464- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
465- py::arg (" compiled_dims" ) = " nk" );
466- m.def (" m_grouped_bf16_gemm_nt_masked" , &m_grouped_bf16_gemm_nt_masked,
467- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" masked_m" ),
468- py::arg (" expected_m" ), py::arg (" compiled_dims" ) = " nk" );
469- }
470-
471- } // namespace deep_gemm::gemm
406+ } // namespace deep_gemm::gemm
0 commit comments