7 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   23         total_sk_iters_ = sk_tiles_ * iters_per_tile_;
 
   26         if(total_sk_iters_ < 
grid_)
 
   40     iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
 
   41     extra_iters_      = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
 
   44     total_dp_iters_ = 
dp_tiles_ * iters_per_tile_;
 
   47 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   49 StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
 
   50     index_t acc_element_bytes) 
const noexcept
 
   52     return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
 
   55 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   57 StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
 
   60     return sizeof(
index_t) * sk_ctas_;
 
   63 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   69     iter     = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
 
   70     iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
 
   73 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   78     return iter / iters_per_tile_;
 
   81 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   86     tile_iter     = tile_idx * iters_per_tile_;
 
   87     tile_iter_end = tile_iter + iters_per_tile_;
 
   90 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
   95     return iter - tile_iter;
 
   98 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  103     return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
 
  106 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  118 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  121     index_t acc_element_bytes) 
const noexcept
 
  126         return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
 
  134 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  142 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  149 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  156 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  163 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  170 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  175     return total_sk_iters_;
 
  178 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  183     return iters_per_tile_;
 
  186 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  191     return iters_per_sk_cta_;
 
  194 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  202 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  207     return total_dp_iters_;
 
  210 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  217 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  224     int num_wgs_per_tile = 1;
 
  231         num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
 
  232                            ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
 
  235     return std::max(num_wgs_per_tile, 1);
 
  238 template <
typename BlockGemmShapeType,
 
  244 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  252     dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
 
  253     extra_dp_tiles_   = this->dp_tiles_ % this->grid_;
 
  256 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  259     const noexcept -> dim3
 
  261     if(extra_dp_tiles_ == 0)
 
  263         return dim3(this->grid_, 1, 1);
 
  267         return dim3(this->num_tiles_, 1, 1);
 
  271 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  276     return dp_tiles_per_cta_;
 
  279 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  284     return extra_dp_tiles_;
 
  288 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  296     dp_ctas_            = this->dp_tiles_;
 
  297     dp_start_block_idx_ = 0;
 
  298     sk_start_block_idx_ = this->dp_tiles_;
 
  301 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  304     const noexcept -> dim3
 
  306     return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
 
  309 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  317 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  322     return dp_start_block_idx_;
 
  325 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
 
  330     return sk_start_block_idx_;
 
#define CK_TILE_DEVICE
Definition: config.hpp:41
 
#define CK_TILE_HOST
Definition: config.hpp:40
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
Definition: cluster_descriptor.hpp:13
 
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
 
StreamKReductionStrategy
Definition: streamk_common.hpp:10
 
int32_t index_t
Definition: integer.hpp:9
 
ck_tile::index_t estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
Estimates the number of Stream-K workgroups per macro tile in the C tensor.
Definition: streamk_common.hpp:27
 
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
 
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
 
@ Atomic
Definition: block_to_ctile_map.hpp:1012
 
@ Reduction
Definition: block_to_ctile_map.hpp:1013
 
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:299
 
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:231
 
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
 
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:196
 
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:29
 
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
 
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
 
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:195
 
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:8
 
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:197
 
Definition: tuple.hpp:192