22 template <
typename Problem>
27 constexpr
index_t data_bytes =
sizeof(
typename Problem::ADataType);
28 static_assert(copy_bytes % data_bytes == 0);
29 return copy_bytes / data_bytes;
32 template <
typename Problem>
35 constexpr
index_t copy_bytes = [&]() {
return 16; }();
36 constexpr
index_t data_bytes =
sizeof(
typename Problem::GDataType);
37 static_assert(copy_bytes % data_bytes == 0);
38 return copy_bytes / data_bytes;
41 template <
typename Problem>
44 constexpr
index_t copy_bytes = [&]() {
return 16; }();
45 constexpr
index_t data_bytes =
sizeof(
typename Problem::DDataType);
46 static_assert(copy_bytes % data_bytes == 0);
47 return copy_bytes / data_bytes;
50 template <
typename Problem>
53 if constexpr(Problem::Traits::OAtomic == 1)
56 static_assert(
sizeof(
typename Problem::ODataType) == 2);
59 else if constexpr(Problem::Traits::OAtomic == 2)
66 return 16 /
sizeof(
typename Problem::ODataType);
70 template <
typename DataType_>
77 template <
typename Problem>
80 return GetSmemKPack<typename Problem::ADataType>();
84 template <
typename Problem>
88 return 16 /
sizeof(
typename Problem::YDataType);
91 template <
typename Problem>
94 constexpr
auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
95 constexpr
auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
96 static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
97 return a_sld_desc.get_element_space_size();
100 template <
typename Problem>
103 constexpr
auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
104 constexpr
auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
105 static_assert(bridge_sld_desc.get_element_space_size() ==
106 bridge_sst_desc.get_element_space_size());
107 return bridge_sld_desc.get_element_space_size();
110 template <
typename Problem>
113 constexpr
index_t a_lds = GetSmemSize_A<Problem>();
114 constexpr
index_t bridge_lds = GetSmemSize_Bridge<Problem>();
115 return max(a_lds, bridge_lds);
118 template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
121 constexpr
index_t K_vec = Alignment;
122 constexpr
index_t K_rem = KPerBlock / K_vec;
// Compile-time guard: K_wav may not exceed NumWarps (repeating along K is unsupported).
129 static_assert(K_wav <= NumWarps,
"do not support thread has repeat along K yet");
130 constexpr
index_t M_wav = NumWarps / K_wav;
131 static_assert(MPerBlock % M_wav == 0,
"this tile size is too small please check");
132 constexpr
index_t M_rep = MPerBlock / M_wav;
145 constexpr
index_t K_lan = K_rem;
147 constexpr
index_t M_wav = NumWarps;
148 static_assert(MPerBlock % (M_lan * M_wav) == 0,
149 "this tile size is too small please check");
150 constexpr
index_t M_rep = MPerBlock / (M_lan * M_wav);
163 template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
166 constexpr
index_t K_vec = Alignment;
167 constexpr
index_t K_rem = KPerBlock / K_vec;
174 static_assert(K_wav <= NumWarps,
"do not support thread has repeat along K yet");
175 constexpr
index_t M_wav = NumWarps / K_wav;
176 static_assert(MPerBlock % M_wav == 0,
"this tile size is too small please check");
177 constexpr
index_t M_rep = MPerBlock / M_wav;
190 constexpr
index_t K_lan = K_rem;
192 constexpr
index_t M_wav = NumWarps;
193 static_assert(MPerBlock % (M_lan * M_wav) == 0,
194 "this tile size is too small please check");
195 constexpr
index_t M_rep = MPerBlock / (M_lan * M_wav);
210 template <
index_t WarpPerBlock_N_,
229 template <
typename Problem>
232 constexpr
index_t Block_M_ = Problem::BlockShape::Block_M0;
233 constexpr
index_t Block_K_ = Problem::BlockShape::Block_K0;
234 constexpr
index_t NumWarps_ = Problem::BlockShape::NumWarps;
235 constexpr
index_t Alignment_ = GetAlignment_A<Problem>();
242 template <
typename Problem>
245 constexpr
auto PermuteEnum = Problem::Traits::PermuteEnum;
247 using S_ =
typename Problem::BlockShape;
257 GetAlignment_G<Problem>()>();
261 template <
typename Problem>
264 constexpr
auto PermuteEnum = Problem::Traits::PermuteEnum;
265 using S_ =
typename Problem::BlockShape;
273 GetAlignment_D<Problem>()>();
277 template <
typename Problem>
284 constexpr
auto c_block_outer_dstr_encoding =
294 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
299 template <
typename Problem>
303 constexpr
index_t Block_M = Problem::BlockShape::Block_M0;
304 constexpr
index_t Block_K = Problem::BlockShape::Block_K0;
307 constexpr
index_t NumWarps = Problem::BlockShape::NumWarps;
309 constexpr
index_t KPack = GetSmemKPack_A<Problem>();
310 constexpr
index_t KVector = GetAlignment_A<Problem>();
311 constexpr
index_t KPad = KPack;
313 static_assert(Block_K % KVector == 0);
314 constexpr
index_t LanesPerK = Block_K / KVector;
315 if constexpr(LanesPerK >= WarpSize)
318 static_assert(LanesPerK % WarpSize == 0);
319 constexpr
index_t wavesPerK = LanesPerK / WarpSize;
320 if constexpr(wavesPerK > NumWarps)
326 constexpr
index_t wavesPerM = NumWarps / wavesPerK;
327 constexpr
index_t NumIssues = Block_M / wavesPerM;
335 number<wavesPerK*(WarpSize * KVector + KPad)>{},
351 return lds_block_desc_issues_warps_lanes;
357 static_assert(WarpSize % LanesPerK == 0);
358 constexpr
index_t LaneGroups = WarpSize / LanesPerK;
359 constexpr
index_t NumIssues = Block_M / (LaneGroups * NumWarps);
384 return lds_block_desc_issues_warps_lanes;
388 template <
typename Problem>
398 constexpr
index_t Block_M = Problem::BlockShape::Block_M0;
399 constexpr
index_t Block_K = Problem::BlockShape::Block_K0;
402 constexpr
index_t NumWarps = Problem::BlockShape::NumWarps;
404 constexpr
index_t KPack = GetSmemKPack_A<Problem>();
405 constexpr
index_t KVector = GetAlignment_A<Problem>();
406 constexpr
index_t KPad = KPack;
408 static_assert(Block_K % KVector == 0);
409 constexpr
index_t LanesPerK = Block_K / KVector;
410 if constexpr(LanesPerK >= WarpSize)
413 static_assert(LanesPerK % WarpSize == 0);
414 constexpr
index_t wavesPerK = LanesPerK / WarpSize;
415 if constexpr(wavesPerK >= NumWarps)
421 constexpr
index_t wavesPerM = NumWarps / wavesPerK;
422 constexpr
index_t NumIssues = Block_M / wavesPerM;
430 number<wavesPerK*(WarpSize * KVector + KPad)>{},
452 static_assert(WarpSize % LanesPerK == 0);
453 constexpr
index_t LaneGroups = WarpSize / LanesPerK;
454 constexpr
index_t NumIssues = Block_M / (LaneGroups * NumWarps);
483 template <
typename Problem>
486 constexpr
index_t Block_M = Problem::BlockShape::Block_M0;
487 constexpr
index_t Block_N = Problem::BlockShape::Block_N0;
489 constexpr
index_t KVector = GetSmemKPack_Y<Problem>();
492 constexpr
auto desc =
500 template <
typename Problem>
503 constexpr
index_t Block_M = Problem::BlockShape::Block_M0;
504 constexpr
index_t Block_N = Problem::BlockShape::Block_N0;
506 constexpr
index_t KVector = GetSmemKPack_Y<Problem>();
509 constexpr
auto desc =
517 template <
typename Problem>
520 constexpr
index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
521 constexpr
index_t Repeat_N = Problem::BlockShape::Repeat_N0;
522 constexpr
index_t Repeat_M = Problem::BlockShape::Repeat_M0;
524 constexpr
index_t kAMLane = 16;
525 constexpr
index_t kABKLane = 4;
526 constexpr
index_t kABKPerLane = 4;
528 constexpr
index_t KPack = kABKPerLane;
559 template <
typename Problem>
562 using S_ =
typename Problem::BlockShape;
567 if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
568 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
569 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
575 else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
576 std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
577 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
585 template <
typename Problem>
591 using S_ =
typename Problem::BlockShape;
595 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
596 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
597 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
598 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
604 constexpr
auto seq_all =
617 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
618 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
619 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
620 S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
626 constexpr
auto seq_all =
637 template <
typename Problem>
643 using S_ =
typename Problem::BlockShape;
646 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
647 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
648 S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
649 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
655 constexpr
auto seq_all =
668 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
669 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
670 S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
671 S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
677 constexpr
auto seq_all =
688 template <
typename Problem>
691 using S_ =
typename Problem::BlockShape;
694 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
695 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
696 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
702 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
703 std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
704 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
712 template <
typename Problem>
717 using CDataType =
typename WarpGemm::CDataType;
719 constexpr
auto c_block_outer_dstr_encoding =
729 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
731 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
732 return c_block_tensor;
735 template <
typename Problem>
740 using CDataType =
typename WarpGemm::CDataType;
742 constexpr
auto c_block_outer_dstr_encoding =
752 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
754 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
755 return c_block_tensor;
759 template <
typename Problem>
766 constexpr
auto y_outer_dstr_enc =
775 y_outer_dstr_enc,
typename WarpGemm::AWarpDstrEncoding{});
780 template <
typename Problem>
783 constexpr
auto y_block_dstr = MakeYTileDistribution<Problem>();
784 auto y_block_tensor =
785 make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
786 return y_block_tensor;
789 template <
typename Problem>
792 using S_ =
typename Problem::BlockShape;
793 if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
794 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
795 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
796 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
800 else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
801 std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
802 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
803 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
809 template <
typename Problem>
812 using S_ =
typename Problem::BlockShape;
813 using T_ =
typename Problem::Traits;
814 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
815 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
816 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
817 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
818 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
819 T_::PipeInterleave ==
false)
824 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
825 std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
826 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
827 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
828 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
829 T_::PipeInterleave ==
false)
834 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
835 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
836 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
837 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
838 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
839 T_::PipeInterleave ==
true)
844 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
845 std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
846 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
847 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
848 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
849 T_::PipeInterleave ==
true)
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:197
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:401
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:540
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:15
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:262
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:736
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:278
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_Y()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:85
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:71
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:33
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:501
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:92
static constexpr CK_TILE_HOST_DEVICE auto MakeYTileDistribution()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:760
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_Bridge()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:101
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK_Async()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:164
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:119
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:560
static constexpr CK_TILE_HOST_DEVICE index_t GetAsyncCopyDwords()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:16
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:51
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:42
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreForUKDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:518
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:23
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:586
static constexpr CK_TILE_HOST_DEVICE auto MakeYBlockTile()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:781
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:300
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:638
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsLoadDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:484
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:111
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:389
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:713
static constexpr CK_TILE_HOST_DEVICE auto GetUK_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:810
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:78
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_Nr_Kr_W()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:216
static constexpr CK_TILE_HOST_DEVICE auto GetUK_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:790
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:243
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:689
Definition: warp_gemm_attribute_mfma_impl.hpp:1596
Definition: warp_gemm_attribute_mfma_impl.hpp:448
Definition: warp_gemm_attribute_mfma.hpp:701
Definition: warp_gemm_impl.hpp:11
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192