20 template <
typename ABDataType,
21 typename FloatGemmAcc,
22 typename EDataTypeShuffle,
24 typename AElementwiseOperation,
25 typename BElementwiseOperation,
26 typename EElementwiseOperation,
28 typename AGridDesc_M_K,
29 typename BGridDesc_N_K,
30 typename EGridDesc_M_N,
32 index_t TileLoadThreadGroupSize,
33 index_t TileMathThreadGroupSize,
43 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 index_t ABlockTransferSrcVectorDim,
47 index_t ABlockTransferSrcScalarPerVector,
48 index_t ABlockTransferDstScalarPerVector_AK1,
49 bool AThreadTransferSrcResetCoordinateAfterRun,
51 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
52 typename BBlockTransferThreadClusterArrangeOrder,
53 typename BBlockTransferSrcAccessOrder,
54 index_t BBlockTransferSrcVectorDim,
55 index_t BBlockTransferSrcScalarPerVector,
56 index_t BBlockTransferDstScalarPerVector_BK1,
57 bool BThreadTransferSrcResetCoordinateAfterRun,
59 index_t CShuffleMXdlPerWavePerShuffle,
60 index_t CShuffleNXdlPerWavePerShuffle,
61 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
62 index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
131 __host__ __device__
static constexpr
auto
134 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
135 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
137 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
144 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
157 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
160 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
163 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
166 constexpr
auto c_block_size =
167 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
169 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
171 c_block_size *
sizeof(EDataTypeShuffle));
175 template <
typename Block2ETileMap>
176 __host__ __device__
static constexpr
bool
178 const BGridDesc_N_K& b_grid_desc_n_k,
179 const EGridDesc_M_N& e_grid_desc_m_n,
180 const Block2ETileMap& )
182 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
183 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
184 "Invalid tuning param!");
186 const auto M = a_grid_desc_m_k.GetLength(
I0);
187 const auto N = b_grid_desc_n_k.GetLength(
I0);
188 const auto K = a_grid_desc_m_k.GetLength(
I1);
191 if(!(M == e_grid_desc_m_n.GetLength(
I0) && N == e_grid_desc_m_n.GetLength(
I1) &&
192 K == b_grid_desc_n_k.GetLength(
I1)))
198 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
204 const auto num_k_loop = K / KPerBlock;
206 if(!GridwiseGemmMath::IsSupported(num_k_loop))
216 if(!(a_grid_desc_m_k.GetElementSpaceSize() *
sizeof(ABDataType) <= TwoGB &&
217 b_grid_desc_n_k.GetElementSpaceSize() *
sizeof(ABDataType) <= TwoGB &&
218 e_grid_desc_m_n.GetElementSpaceSize() *
sizeof(EDataType) <= TwoGB))
228 const index_t num_loop = K / KPerBlock;
230 return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
234 __host__ __device__
static constexpr
auto
237 const auto M = e_grid_desc_m_n.GetLength(
I0);
238 const auto N = e_grid_desc_m_n.GetLength(
I1);
243 const auto M0 = M / M1;
244 const auto N0 = N / N1;
246 constexpr
auto M01 =
I1;
247 constexpr
auto N01 =
I1;
249 const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
256 const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
262 const auto cblockid_to_m0_n0_block_cluster_adaptor =
264 cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
266 return cblockid_to_m0_n0_block_cluster_adaptor;
269 __host__ __device__
static constexpr
index_t
272 const auto M = e_grid_desc_m_n.GetLength(
I0);
273 const auto N = e_grid_desc_m_n.GetLength(
I1);
275 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
281 __host__ __device__
static constexpr
auto
284 const auto M = a_grid_desc_m_k.GetLength(
I0);
285 const auto K = a_grid_desc_m_k.GetLength(
I1);
287 const auto AK0 = K /
AK1;
297 __host__ __device__
static constexpr
auto
300 const auto N = b_grid_desc_n_k.GetLength(
I0);
301 const auto K = b_grid_desc_n_k.GetLength(
I1);
303 const auto BK0 = K /
BK1;
313 template <
typename EGr
idDescriptor_M_N>
315 const EGridDescriptor_M_N& e_grid_desc_m_n)
317 const auto M = e_grid_desc_m_n.GetLength(
I0);
318 const auto N = e_grid_desc_m_n.GetLength(
I1);
320 const auto MBlock = M / MPerBlock;
321 const auto NBlock = N / NPerBlock;
330 return e_grid_desc_mblock_mperblock_nblock_nperblock;
340 template <
bool HasMainKBlockLoop,
341 typename AGridDesc_AK0_M_AK1,
342 typename BGridDesc_BK0_N_BK1,
343 typename Block2ETileMap>
344 __device__
static void Run(
const ABDataType* __restrict__ p_a_grid,
345 const ABDataType* __restrict__ p_b_grid,
346 EDataType* __restrict__ p_e_grid,
347 void* __restrict__ p_shared,
348 const AElementwiseOperation& a_element_op,
349 const BElementwiseOperation& b_element_op,
350 const EElementwiseOperation& e_element_op,
351 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
352 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
354 e_grid_desc_mblock_mperblock_nblock_nperblock,
355 const Block2ETileMap& block_2_etile_map)
371 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
373 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
374 static_cast<ABDataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
376 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
377 static_cast<ABDataType*
>(p_shared) + a_block_space_size_aligned,
378 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
383 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
384 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
388 const auto block_work_idx =
392 const index_t m_block_data_idx_on_grid =
393 __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * MPerBlock);
395 const index_t n_block_data_idx_on_grid =
396 __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * NPerBlock);
402 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
403 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
404 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
405 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
408 auto a_blockwise_copy =
410 AElementwiseOperation,
414 ABlockTransferThreadClusterLengths_AK0_M_AK1,
415 ABlockTransferThreadClusterArrangeOrder,
418 decltype(a_grid_desc_ak0_m_ak1),
419 decltype(a_block_desc_ak0_m_ak1),
420 ABlockTransferSrcAccessOrder,
422 ABlockTransferSrcVectorDim,
424 ABlockTransferSrcScalarPerVector,
425 ABlockTransferDstScalarPerVector_AK1,
428 AThreadTransferSrcResetCoordinateAfterRun,
430 NumGemmKPrefetchStage>(
431 a_grid_desc_ak0_m_ak1,
434 a_block_desc_ak0_m_ak1,
439 auto b_blockwise_copy =
441 BElementwiseOperation,
445 BBlockTransferThreadClusterLengths_BK0_N_BK1,
446 BBlockTransferThreadClusterArrangeOrder,
449 decltype(b_grid_desc_bk0_n_bk1),
450 decltype(b_block_desc_bk0_n_bk1),
451 BBlockTransferSrcAccessOrder,
453 BBlockTransferSrcVectorDim,
455 BBlockTransferSrcScalarPerVector,
456 BBlockTransferDstScalarPerVector_BK1,
459 BThreadTransferSrcResetCoordinateAfterRun,
461 NumGemmKPrefetchStage>(
462 b_grid_desc_bk0_n_bk1,
465 b_block_desc_bk0_n_bk1,
469 GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
470 a_grid_desc_ak0_m_ak1,
471 a_block_desc_ak0_m_ak1,
475 a_block_slice_copy_step,
476 b_grid_desc_bk0_n_bk1,
477 b_block_desc_bk0_n_bk1,
481 b_block_slice_copy_step,
482 num_k_block_main_loop);
491 constexpr
bool is_single_rate_mfma =
499 constexpr
auto is_scale_mfma =
false;
507 is_scale_mfma>::selected_mfma.k_per_blk);
510 TileMathThreadGroupSize,
514 decltype(a_block_desc_ak0_m_ak1),
515 decltype(b_block_desc_bk0_n_bk1),
523 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
524 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
528 GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
529 a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
541 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
542 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
545 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
546 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
549 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
550 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
554 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
555 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
557 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
558 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
559 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
560 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
561 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
562 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
563 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
564 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
566 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
569 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
570 static_cast<EDataTypeShuffle*
>(p_shared),
571 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
574 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
596 const auto c_thread_mtx_on_block =
597 blockwise_gemm.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
599 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
600 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
602 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
608 const auto m_thread_data_on_block_idx =
609 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
612 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
618 const auto n_thread_data_on_block_idx =
619 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
626 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
627 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
629 Sequence<CShuffleMXdlPerWavePerShuffle,
630 CShuffleNXdlPerWavePerShuffle,
642 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
645 m_thread_data_on_block_idx[
I1],
646 n_thread_data_on_block_idx[
I1],
647 m_thread_data_on_block_idx[
I2],
648 m_thread_data_on_block_idx[
I3],
649 m_thread_data_on_block_idx[
I4],
650 n_thread_data_on_block_idx[
I2]),
656 EElementwiseOperation,
657 CGlobalMemoryDataOperation,
659 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
661 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
662 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
666 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
667 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
670 CShuffleBlockTransferScalarPerVector_NPerBlock,
673 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
675 e_grid_desc_mblock_mperblock_nblock_nperblock,
680 constexpr
auto sfc_c_vgpr =
683 Sequence<CShuffleMXdlPerWavePerShuffle,
684 CShuffleNXdlPerWavePerShuffle,
693 constexpr
auto sfc_c_global =
697 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
699 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
701 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
703 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
732 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
733 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
735 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
736 c_shuffle_block_buf);
741 c_shuffle_block_copy_lds_to_global.Run(
742 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
744 e_grid_desc_mblock_mperblock_nblock_nperblock,
747 if constexpr(access_id < num_access - 1)
749 constexpr
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
752 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
753 e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:52
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ constexpr __device__ auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition: tensor_adaptor.hpp:245
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:79
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:82
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:90
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:85
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:83
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:97
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:105
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:98
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:100
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:235
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:132
static constexpr auto I3
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:69
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:147
static constexpr auto AK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:314
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:298
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
static constexpr auto BK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:79
static constexpr auto I7
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:73
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:115
static constexpr auto I0
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:66
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:123
ThisThreadBlock< TileMathThreadGroupSize > CShuffleBlockTransferThreadGroup
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:108
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:338
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const EElementwiseOperation &e_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:344
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:226
static constexpr auto AK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:76
static constexpr auto I2
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:68
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:270
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:335
static constexpr auto I5
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:71
static constexpr auto I6
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:72
static constexpr auto I4
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:70
static constexpr auto BK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:77
static constexpr auto I1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:67
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:282
Definition: gridwise_gemm_waveletmodel.hpp:11
Definition: gridwise_gemm_waveletmodel.hpp:103
Definition: xdlops_gemm.hpp:1126
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:308