@@ -33,7 +33,6 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
3333 const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
3434#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
3535 using Barrier = cutlass::arch::ClusterTransactionBarrier;
36- using Allocator = cute::conditional_t <kNumMulticast == 1 , cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
3736
3837 // GEMM with accumulation must have FP32 output
3938 if constexpr (kWithAccumulation )
@@ -170,7 +169,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
170169 cutlass::arch::fence_barrier_init ();
171170 } else if (threadIdx .x >= 32 and threadIdx .x < 64 ) {
172171 // Allocate tensor memory
173- Allocator ().allocate (kNumTmemCols , tmem_ptr_in_smem);
172+ cute::TMEM::Allocator1Sm ().allocate (kNumTmemCols , tmem_ptr_in_smem);
174173 }
175174 kNumMulticast > 1 ? cute::cluster_sync () : __syncthreads ();
176175
@@ -578,13 +577,15 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
578577 }
579578
580579 // Flush all stages in the pipeline to make TMA stores visible to the next kernel
580+ // TODO: do we actually need this?
581581 if (epilogue_thread_idx == 0 )
582582 cute::tma_store_wait<0 >();
583583
584584 // Deallocate tensor memory by warp 1
585585 // NOTES: warp 0 is waiting TMA store
586+ // TODO: do we need 2 SM allocation?
586587 if (epilogue_warp_idx == 1 )
587- Allocator ().free (0 , kNumTmemCols );
588+ cute::TMEM::Allocator1Sm ().free (0 , kNumTmemCols );
588589 }
589590
590591 // To safely deconstruct all barriers, we need a cluster sync
0 commit comments