20 template <
typename ABDataType,
21 typename FloatGemmAcc,
22 typename EDataTypeShuffle,
24 typename AElementwiseOperation,
25 typename BElementwiseOperation,
26 typename EElementwiseOperation,
28 typename AGridDesc_M_K,
29 typename BGridDesc_N_K,
30 typename EGridDesc_M_N,
32 index_t TileLoadThreadGroupSize,
33 index_t TileMathThreadGroupSize,
43 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 index_t ABlockTransferSrcVectorDim,
47 index_t ABlockTransferSrcScalarPerVector,
48 index_t ABlockTransferDstScalarPerVector_AK1,
49 bool AThreadTransferSrcResetCoordinateAfterRun,
51 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
52 typename BBlockTransferThreadClusterArrangeOrder,
53 typename BBlockTransferSrcAccessOrder,
54 index_t BBlockTransferSrcVectorDim,
55 index_t BBlockTransferSrcScalarPerVector,
56 index_t BBlockTransferDstScalarPerVector_BK1,
57 bool BThreadTransferSrcResetCoordinateAfterRun,
59 index_t CShuffleMXdlPerWavePerShuffle,
60 index_t CShuffleNXdlPerWavePerShuffle,
61 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
62 index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
132 __host__ __device__
static constexpr
auto
135 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
136 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
138 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
145 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
158 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
161 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
164 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
167 constexpr
auto c_block_size =
168 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
170 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
172 c_block_size *
sizeof(EDataTypeShuffle));
179 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
188 CGlobalMemoryDataOperation>();
192 template <
typename Block2ETileMap>
193 __host__ __device__
static constexpr
bool
195 const BGridDesc_N_K& b_grid_desc_n_k,
196 const EGridDesc_M_N& e_grid_desc_m_n,
197 const Block2ETileMap& )
199 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
200 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
201 "Invalid tuning param!");
203 const auto M = a_grid_desc_m_k.GetLength(
I0);
204 const auto N = b_grid_desc_n_k.GetLength(
I0);
205 const auto K = a_grid_desc_m_k.GetLength(
I1);
208 if(!(M == e_grid_desc_m_n.GetLength(
I0) && N == e_grid_desc_m_n.GetLength(
I1) &&
209 K == b_grid_desc_n_k.GetLength(
I1)))
215 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
221 const auto num_k_loop = K / KPerBlock;
223 if(!GridwiseGemmMath::IsSupported(num_k_loop))
233 if(!(a_grid_desc_m_k.GetElementSpaceSize() *
sizeof(ABDataType) <= TwoGB &&
234 b_grid_desc_n_k.GetElementSpaceSize() *
sizeof(ABDataType) <= TwoGB &&
235 e_grid_desc_m_n.GetElementSpaceSize() *
sizeof(EDataType) <= TwoGB))
245 const index_t num_loop = K / KPerBlock;
247 return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
251 __host__ __device__
static constexpr
auto
254 const auto M = e_grid_desc_m_n.GetLength(
I0);
255 const auto N = e_grid_desc_m_n.GetLength(
I1);
260 const auto M0 = M / M1;
261 const auto N0 = N / N1;
263 constexpr
auto M01 =
I1;
264 constexpr
auto N01 =
I1;
266 const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
273 const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
279 const auto cblockid_to_m0_n0_block_cluster_adaptor =
281 cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
283 return cblockid_to_m0_n0_block_cluster_adaptor;
286 __host__ __device__
static constexpr
index_t
289 const auto M = e_grid_desc_m_n.GetLength(
I0);
290 const auto N = e_grid_desc_m_n.GetLength(
I1);
292 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
298 __host__ __device__
static constexpr
auto
301 const auto M = a_grid_desc_m_k.GetLength(
I0);
302 const auto K = a_grid_desc_m_k.GetLength(
I1);
304 const auto AK0 = K /
AK1;
314 __host__ __device__
static constexpr
auto
317 const auto N = b_grid_desc_n_k.GetLength(
I0);
318 const auto K = b_grid_desc_n_k.GetLength(
I1);
320 const auto BK0 = K /
BK1;
330 template <
typename EGr
idDescriptor_M_N>
332 const EGridDescriptor_M_N& e_grid_desc_m_n)
334 const auto M = e_grid_desc_m_n.GetLength(
I0);
335 const auto N = e_grid_desc_m_n.GetLength(
I1);
337 const auto MBlock = M / MPerBlock;
338 const auto NBlock = N / NPerBlock;
347 return e_grid_desc_mblock_mperblock_nblock_nperblock;
357 template <
bool HasMainKBlockLoop,
358 typename AGridDesc_AK0_M_AK1,
359 typename BGridDesc_BK0_N_BK1,
360 typename Block2ETileMap>
361 __device__
static void Run(
const ABDataType* __restrict__ p_a_grid,
362 const ABDataType* __restrict__ p_b_grid,
363 EDataType* __restrict__ p_e_grid,
364 void* __restrict__ p_shared,
365 const AElementwiseOperation& a_element_op,
366 const BElementwiseOperation& b_element_op,
367 const EElementwiseOperation& e_element_op,
368 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
369 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
371 e_grid_desc_mblock_mperblock_nblock_nperblock,
372 const Block2ETileMap& block_2_etile_map)
388 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
390 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
391 static_cast<ABDataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
393 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
394 static_cast<ABDataType*
>(p_shared) + a_block_space_size_aligned,
395 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
400 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
401 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
405 const auto block_work_idx =
409 const index_t m_block_data_idx_on_grid =
410 __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * MPerBlock);
412 const index_t n_block_data_idx_on_grid =
413 __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * NPerBlock);
419 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
420 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
421 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
422 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
425 auto a_blockwise_copy =
427 AElementwiseOperation,
431 ABlockTransferThreadClusterLengths_AK0_M_AK1,
432 ABlockTransferThreadClusterArrangeOrder,
435 decltype(a_grid_desc_ak0_m_ak1),
436 decltype(a_block_desc_ak0_m_ak1),
437 ABlockTransferSrcAccessOrder,
439 ABlockTransferSrcVectorDim,
441 ABlockTransferSrcScalarPerVector,
442 ABlockTransferDstScalarPerVector_AK1,
445 AThreadTransferSrcResetCoordinateAfterRun,
447 NumGemmKPrefetchStage>(
448 a_grid_desc_ak0_m_ak1,
451 a_block_desc_ak0_m_ak1,
456 auto b_blockwise_copy =
458 BElementwiseOperation,
462 BBlockTransferThreadClusterLengths_BK0_N_BK1,
463 BBlockTransferThreadClusterArrangeOrder,
466 decltype(b_grid_desc_bk0_n_bk1),
467 decltype(b_block_desc_bk0_n_bk1),
468 BBlockTransferSrcAccessOrder,
470 BBlockTransferSrcVectorDim,
472 BBlockTransferSrcScalarPerVector,
473 BBlockTransferDstScalarPerVector_BK1,
476 BThreadTransferSrcResetCoordinateAfterRun,
478 NumGemmKPrefetchStage>(
479 b_grid_desc_bk0_n_bk1,
482 b_block_desc_bk0_n_bk1,
486 GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
487 a_grid_desc_ak0_m_ak1,
488 a_block_desc_ak0_m_ak1,
492 a_block_slice_copy_step,
493 b_grid_desc_bk0_n_bk1,
494 b_block_desc_bk0_n_bk1,
498 b_block_slice_copy_step,
499 num_k_block_main_loop);
508 constexpr
bool is_single_rate_mfma =
516 constexpr
auto is_scale_mfma =
false;
524 is_scale_mfma>::selected_mfma.k_per_blk);
527 TileMathThreadGroupSize,
531 decltype(a_block_desc_ak0_m_ak1),
532 decltype(b_block_desc_bk0_n_bk1),
540 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
541 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
545 GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
546 a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
558 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
559 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
562 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
563 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
566 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
567 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
571 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
572 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
574 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
575 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
576 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
577 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
578 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
579 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
580 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
581 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
583 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
586 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
587 static_cast<EDataTypeShuffle*
>(p_shared),
588 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
591 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
613 const auto c_thread_mtx_on_block =
614 blockwise_gemm.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
616 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
617 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
619 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
625 const auto m_thread_data_on_block_idx =
626 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
629 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
635 const auto n_thread_data_on_block_idx =
636 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
643 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
644 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
646 Sequence<CShuffleMXdlPerWavePerShuffle,
647 CShuffleNXdlPerWavePerShuffle,
659 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
662 m_thread_data_on_block_idx[
I1],
663 n_thread_data_on_block_idx[
I1],
664 m_thread_data_on_block_idx[
I2],
665 m_thread_data_on_block_idx[
I3],
666 m_thread_data_on_block_idx[
I4],
667 n_thread_data_on_block_idx[
I2]),
673 EElementwiseOperation,
674 CGlobalMemoryDataOperation,
676 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
678 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
679 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
683 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
684 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
687 CShuffleBlockTransferScalarPerVector_NPerBlock,
690 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
692 e_grid_desc_mblock_mperblock_nblock_nperblock,
697 constexpr
auto sfc_c_vgpr =
700 Sequence<CShuffleMXdlPerWavePerShuffle,
701 CShuffleNXdlPerWavePerShuffle,
710 constexpr
auto sfc_c_global =
714 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
716 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
718 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
720 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
749 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
750 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
752 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
753 c_shuffle_block_buf);
758 c_shuffle_block_copy_lds_to_global.Run(
759 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
761 e_grid_desc_mblock_mperblock_nblock_nperblock,
764 if constexpr(access_id < num_access - 1)
766 constexpr
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
769 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
770 e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
__host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:277
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:300
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ constexpr __device__ auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition: tensor_adaptor.hpp:245
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:78
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:83
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:91
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:86
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:84
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:98
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:106
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:99
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:101
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
static constexpr auto BlockSize
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:80
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:252
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:133
static constexpr auto I3
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:69
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:148
static constexpr auto AK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:331
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:315
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:194
static constexpr auto BK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:79
static constexpr auto I7
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:73
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:116
static constexpr auto I0
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:66
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:124
ThisThreadBlock< TileMathThreadGroupSize > CShuffleBlockTransferThreadGroup
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:109
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:355
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const EElementwiseOperation &e_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:361
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:243
static constexpr auto AK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:76
static constexpr auto I2
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:68
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:287
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:352
static constexpr auto I5
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:71
static constexpr auto I6
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:72
static constexpr auto I4
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:70
static constexpr auto BK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:77
static constexpr auto I1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:67
static __device__ constexpr bool IsValidCompilationParameter()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:299
Definition: gridwise_gemm_waveletmodel.hpp:11
Definition: gridwise_gemm_waveletmodel.hpp:103
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1208
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334