@@ -8,31 +8,31 @@ constexpr static uint32_t kNumMaxRanks = 64;
88
99template <uint32_t kNumRanks = kNumMaxRanks >
1010struct SymBuffer {
11- uint64_t offsets[ kNumMaxRanks ] ;
12-
13- uint32_t rank_idx = 0 ;
11+ int64_t base ;
12+ int64_t offsets[ kNumMaxRanks ];
13+ uint32_t rank_idx;
1414
1515 DG_STATIC_ASSERT (kNumRanks <= kNumMaxRanks , " Too many ranks" );
1616
1717 SymBuffer () = default ;
1818
1919 template <typename Container>
20- explicit SymBuffer (const Container& c, const uint32_t & rank_idx = 0 ): rank_idx(rank_idx) {
20+ explicit SymBuffer (const Container& c, const uint32_t & rank_idx): rank_idx(rank_idx) {
2121 const auto size = static_cast <uint32_t >(c.size ());
22+ base = c[rank_idx];
2223 for (uint32_t i = 0 ; i < kNumMaxRanks ; ++ i)
23- offsets[i] = i < size ? c[i] : 0 ;
24+ offsets[i] = i < size ? ( c[i] - base) : 0 ;
2425 }
2526
2627#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
2728 template <typename ptr_t = void *>
2829 CUTLASS_DEVICE ptr_t get_base_ptr () const {
29- return reinterpret_cast <ptr_t >(offsets[rank_idx] );
30+ return reinterpret_cast <ptr_t >(base );
3031 }
3132
3233 template <typename ptr_t >
3334 CUTLASS_DEVICE ptr_t map (const ptr_t & ptr, const uint32_t & dst_rank_idx) const {
34- uint64_t mapped_ptr = offsets[dst_rank_idx] +
35- (reinterpret_cast <uint64_t >(ptr) - offsets[rank_idx]);
35+ int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast <int64_t >(ptr);
3636 return *reinterpret_cast <ptr_t *>(&mapped_ptr);
3737 }
3838#endif
0 commit comments