/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp Source File
gridwise_gemm_xdlops_bwd_weight.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
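20 // Implementation of the "Merge" transformation primitive that computes the lower index with division and mod.
21 // It is intended for low_lengths that are known at compile time and are powers of 2; otherwise performance will be very bad.
22 // "No carry": UpdateLowerIndex does no division and propagates no carry; it writes the new upper index (and its diff) straight into the last lower dimension.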
23 template <typename LowLengths>
25 {
26  static constexpr index_t NDimLow = LowLengths::Size();
27 
30 
32  decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
33 
34  using UpLengths =
35  decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
36 
37  LowLengths low_lengths_;
40 
41  __host__ __device__ constexpr Merge_v4_no_carry() = default;
42 
43  __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths)
44  : low_lengths_{low_lengths},
46  container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
47  up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
48  {
49  static_assert(LowerIndex::Size() == NDimLow, "wrong!");
50  }
51 
52  __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
53 
54  __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
55 
56  __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
57 
58  template <typename LowIdx, typename UpIdx>
59  __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
60  const UpIdx& idx_up) const
61  {
62  static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
63  "wrong! inconsistent # of dimension");
64 
65  index_t tmp = idx_up[Number<0>{}];
66 
67  // division and mod
68  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
69  idx_low(i) = tmp / this->low_lengths_scan_[i];
70  tmp %= this->low_lengths_scan_[i];
71  });
72 
73  idx_low(Number<NDimLow - 1>{}) = tmp;
74  }
75 
76  template <typename LowIdxDiff,
77  typename UpIdxDiff,
78  typename LowIdx,
79  typename UpIdx,
80  index_t Hack>
81  __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
82  const UpIdxDiff& idx_up_diff,
83  LowIdx& idx_low,
84  const UpIdx& idx_up_new,
85  Number<Hack>) const
86  {
87  static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
88  LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
89  "wrong! inconsistent # of dimension");
90 
91  constexpr auto I0 = Number<0>{};
92  constexpr auto INm1 = Number<NDimLow - 1>{};
93 
94  index_t tmp = idx_up_new[I0];
95 
96  idx_low(INm1) = tmp;
97  idx_diff_low(INm1) = idx_up_diff[I0];
98  }
99 
100  __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
101 
102  __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
103  {
104  return true;
105  }
106 
107  __host__ __device__ static constexpr bool IsKnownAtCompileTime()
108  {
112  }
113 
114  template <typename UpIdx>
115  __host__ __device__ static constexpr bool
116  IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
117  {
118  return true;
119  }
120 
121  __host__ __device__ void Print() const
122  {
123  printf("{");
124  printf("Merge_v3_direct_division_mod_wrw, ");
125  printf("low_lengths_ ");
126  print_multi_index(low_lengths_);
127  printf("low_lengths_scan_ ");
128  print_multi_index(low_lengths_scan_);
129  printf("up_lengths_ ");
130  print_multi_index(up_lengths_);
131  printf("}");
132  }
133 };
134 
135 template <typename LowLengths>
136 __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths)
137 {
138  return Merge_v4_no_carry<LowLengths>{low_lengths};
139 }
140 
141 template <typename GridwiseGemm,
142  typename FloatA,
143  typename FloatB,
144  typename FloatC,
145  typename AGridDesc_B_K0_M_K1,
146  typename BGridDesc_B_K0_N_K1,
147  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
148  typename AElementwiseOperation,
149  typename BElementwiseOperation,
150  typename CElementwiseOperation,
151  typename CBlockClusterAdaptor,
152  bool HasMainKBlockLoop>
153 __global__ void
154 #if CK_USE_LAUNCH_BOUNDS
156 #endif
157  kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
158  const FloatB* __restrict__ p_b_grid,
159  FloatC* __restrict__ p_c_grid,
160  const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
161  const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
162  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
163  c_grid_desc_mblock_mperblock_nblock_nperblock,
164  const AElementwiseOperation a_element_op,
165  const BElementwiseOperation b_element_op,
166  const CElementwiseOperation c_element_op,
167  const CBlockClusterAdaptor c_block_cluster_adaptor)
168 {
169 #if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
170  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
171 
172  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
173  p_b_grid,
174  p_c_grid,
175  p_shared,
176  a_b_k0_m_k1_grid_desc,
177  b_b_k0_n_k1_grid_desc,
178  c_grid_desc_mblock_mperblock_nblock_nperblock,
179  a_element_op,
180  b_element_op,
181  c_element_op,
182  c_block_cluster_adaptor);
183 #else
184  ignore = p_a_grid;
185  ignore = p_b_grid;
186  ignore = p_c_grid;
187  ignore = a_b_k0_m_k1_grid_desc;
188  ignore = b_b_k0_n_k1_grid_desc;
189  ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
190  ignore = a_element_op;
191  ignore = b_element_op;
192  ignore = c_element_op;
193  ignore = c_block_cluster_adaptor;
194 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
195 }
196 
197 template <index_t BlockSize,
198  typename FloatA,
199  typename FloatB,
200  typename FloatAcc,
201  typename FloatC,
202  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
203  typename AGridDesc_B_K0_M_K1,
204  typename BGridDesc_B_K0_N_K1,
205  typename CMNGridDesc,
206  typename AElementwiseOperation,
207  typename BElementwiseOperation,
208  typename CElementwiseOperation,
209  index_t MPerBlock,
210  index_t NPerBlock,
211  index_t K0PerBlock,
212  index_t MPerXDL,
213  index_t NPerXDL,
214  index_t K1Value,
215  index_t MRepeat,
216  index_t NRepeat,
217  typename ABlockTransferThreadClusterLengths_K0_M_K1,
218  typename ABlockTransferThreadClusterArrangeOrder,
219  typename ABlockTransferSrcAccessOrder,
220  index_t ABlockTransferSrcVectorDim,
221  index_t ABlockTransferSrcScalarPerVector,
222  index_t ABlockTransferDstScalarPerVector_K1,
223  bool AThreadTransferSrcResetCoordinateAfterRun,
224  bool ABlockLdsExtraM,
225  index_t ABlockLdsM1PerBlock,
226  index_t ABlockLdsM0PerBlock,
227  index_t ABlockLdsM1Padding,
228  typename BBlockTransferThreadClusterLengths_K0_N_K1,
229  typename BBlockTransferThreadClusterArrangeOrder,
230  typename BBlockTransferSrcAccessOrder,
231  index_t BBlockTransferSrcVectorDim,
232  index_t BBlockTransferSrcScalarPerVector,
233  index_t BBlockTransferDstScalarPerVector_K1,
234  bool BThreadTransferSrcResetCoordinateAfterRun,
235  bool BBlockLdsExtraN,
236  index_t BBlockLdsN1PerBlock,
237  index_t BBlockLdsN0PerBlock,
238  index_t BBlockLdsN1Padding,
239  index_t CShuffleMRepeatPerShuffle,
240  index_t CShuffleNRepeatPerShuffle,
241  index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
242  typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
243  bool ABlockLdsExtraM1Wrw = false,
244  bool BBlockLdsExtraN1Wrw = false,
245  index_t NumGemmKPrefetchStage = 1,
246  PipelineVersion PipelineVer = PipelineVersion::v1,
247  typename ComputeTypeA = FloatA,
248  typename ComputeTypeB = ComputeTypeA>
250 {
251  static constexpr auto I0 = Number<0>{};
252  static constexpr auto I1 = Number<1>{};
253  static constexpr auto I2 = Number<2>{};
254  static constexpr auto I3 = Number<3>{};
255  static constexpr auto I4 = Number<4>{};
256  static constexpr auto I5 = Number<5>{};
257  static constexpr auto I6 = Number<6>{};
258  static constexpr auto I7 = Number<7>{};
259 
260  // K1 should be Number<...>
261  static constexpr auto K1 = Number<K1Value>{};
262 
264 
266  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
267 
268  // denorm test fix, required to work around fp16 mfma issue
269  // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
270  // when mfma if fixed, remove this section and update
271  // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
272  // throughout this file
273 #if CK_GFX90A_DENORM_WORKAROUND
274  using FloatAAdjusted =
276  using FloatBAdjusted =
278 #else
279  using FloatAAdjusted = ComputeTypeA;
280  using FloatBAdjusted = ComputeTypeB;
281 #endif
282 
283  // M0/M1/M1Padding
284  static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
285  static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};
286  static constexpr auto M1Padding = Number<ABlockLdsM1Padding>{};
287 
288  // N0/N1/N1Padding
289  static constexpr auto N1PerBlock = Number<BBlockLdsN1PerBlock>{};
290  static constexpr auto N0PerBlock = Number<BBlockLdsN0PerBlock>{};
291  static constexpr auto N1Padding = Number<BBlockLdsN1Padding>{};
292 
293  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
294  {
295  constexpr auto max_lds_align = K1;
296 
297  // A matrix in LDS memory, dst of blockwise copy
298  constexpr auto a_block_desc_k0_m_k1 = [&]() {
299  if constexpr(ABlockLdsExtraM)
300  {
301  if constexpr(ABlockLdsExtraM1Wrw)
302  {
303  constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
304  make_tuple(
306  make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
307  Number<M1PerBlock>{} * K1 + M1Padding,
308  K1,
309  I1));
310 
311  constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
312  a_block_desc_k0_m0_m1_k1,
319 
320  return a_block_desc_k0_m_k1_tmp;
321  }
322  else
323  {
326  make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
327  }
328  }
329  else
330  {
332  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
333  }
334  }();
335 
336  return a_block_desc_k0_m_k1;
337  }
338 
339  __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
340  {
341  constexpr auto max_lds_align = K1;
342 
343  // A matrix in LDS memory, dst of blockwise copy
344  constexpr auto a_block_desc_b_k0_m_k1 = [&]() {
345  if constexpr(ABlockLdsExtraM)
346  {
347  if constexpr(ABlockLdsExtraM1Wrw)
348  {
349  constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
354  K1),
356  (Number<M1PerBlock>{} * K1 + M1Padding),
357  Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
358  Number<M1PerBlock>{} * K1 + M1Padding,
359  K1,
360  I1));
361 
362  constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
363  a_block_desc_b_k0_m0_m1_k1,
371 
372  return a_block_desc_b_k0_m_k1_tmp;
373  }
374  else
375  {
379  Number<MPerBlock + 1>{} * K1,
380  K1,
381  I1));
382  }
383  }
384  else
385  {
388  max_lds_align);
389  }
390  }();
391 
392  return a_block_desc_b_k0_m_k1;
393  }
394 
395  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
396  {
397  constexpr auto max_lds_align = K1;
398 
399  // B matrix in LDS memory, dst of blockwise copy
400  constexpr auto b_block_desc_k0_n_k1 = [&]() {
401  if constexpr(BBlockLdsExtraN)
402  {
403  if constexpr(BBlockLdsExtraN1Wrw)
404  {
405  constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
406  make_tuple(
408  make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
409  Number<N1PerBlock>{} * K1 + N1Padding,
410  K1,
411  I1));
412 
413  constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
414  b_block_desc_k0_n0_n1_k1,
421 
422  return b_block_desc_k0_n_k1_tmp;
423  }
424  else
425  {
428  make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
429  }
430  }
431  else
432  {
434  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
435  }
436  }();
437 
438  return b_block_desc_k0_n_k1;
439  }
440 
441  __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
442  {
443  constexpr auto max_lds_align = K1;
444 
445  // B matrix in LDS memory, dst of blockwise copy
446  constexpr auto b_block_desc_b_k0_n_k1 = [&]() {
447  if constexpr(BBlockLdsExtraN)
448  {
449  if constexpr(BBlockLdsExtraN1Wrw)
450  {
451  constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
456  K1),
458  (Number<N1PerBlock>{} * K1 + N1Padding),
459  Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
460  Number<N1PerBlock>{} * K1 + N1Padding,
461  K1,
462  I1));
463 
464  constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
465  b_block_desc_b_k0_n0_n1_k1,
473 
474  return b_block_desc_b_k0_n_k1_tmp;
475  }
476  else
477  {
481  Number<NPerBlock + 1>{} * K1,
482  K1,
483  I1));
484  }
485  }
486  else
487  {
490  max_lds_align);
491  }
492  }();
493 
494  return b_block_desc_b_k0_n_k1;
495  }
496 
497  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
498  {
499  constexpr auto max_lds_align = K1;
500 
501  // A matrix in LDS memory, dst of blockwise copy
502  constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
503 
504  // B matrix in LDS memory, dst of blockwise copy
505  constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
506 
507  // LDS allocation for A and B: be careful of alignment
508  constexpr auto a_block_space_size = math::integer_least_multiple(
509  a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
510 
511  constexpr auto b_block_space_size = math::integer_least_multiple(
512  b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
513 
514  constexpr auto c_block_size =
515  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
516 
517  return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
518  b_block_space_size * sizeof(FloatBAdjusted)),
519  c_block_size * sizeof(FloatC));
520  }
521 
522  // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
523  template <typename Block2CTileMap>
524  __host__ __device__ static constexpr bool
525  CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
526  const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
527  const CMNGridDesc& c_m_n_grid_desc,
528  const Block2CTileMap& block_2_ctile_map)
529  {
530  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
531  "wrong! K1 need to be known at compile-time");
532 
533  static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
534  (NPerBlock % (NRepeat * NPerXDL)) == 0,
535  "Invalid tuning param!");
536 
537  const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
538  const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
539  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
540  const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
541 
542  // check gridwise gemm pipeline
543  const auto num_k_loop = K0 / K0PerBlock;
544 
545  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
546  {
547  return false;
548  }
549 
550  if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
551  K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
552  K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
553  K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
554  KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
555  return false;
556 
557  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
558  return false;
559 
560  if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
561  {
562  return false;
563  }
564 
565  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
566  return true;
567  }
568 
569  __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
570  {
571  // const bool has_main_k0_block_loop = K0 > K0PerBlock;
572  const index_t num_loop = K0 / K0PerBlock;
573 
574  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
575 
576  // return has_main_k0_block_loop;
577  }
578 
579  __host__ __device__ static constexpr auto
580  MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
581  {
582  const auto M = c_m_n_grid_desc.GetLength(I0);
583  const auto N = c_m_n_grid_desc.GetLength(I1);
584 
585  const auto MBlock = M / MPerBlock;
586  const auto NBlock = N / NPerBlock;
587 
589  c_m_n_grid_desc,
594  }
595 
596  // return block_id to C matrix tile idx (m0, n0) mapping
597  __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
598  const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
599  {
601  c_m_n_grid_desc, M01, N01, KBatch);
602  }
603 
604  __host__ __device__ static constexpr auto
606  {
607  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
608  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
609 
611  make_tuple(I1,
613  I1,
615  }
616 
618  decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
619  using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
620 
621  template <bool HasMainKBlockLoop>
622  __device__ static void Run(const FloatA* __restrict__ p_a_grid,
623  const FloatB* __restrict__ p_b_grid,
624  FloatC* __restrict__ p_c_grid,
625  void* __restrict__ p_shared,
626  const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
627  const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
629  c_grid_desc_mblock_mperblock_nblock_nperblock,
630  const AElementwiseOperation& a_element_op,
631  const BElementwiseOperation& b_element_op,
632  const CElementwiseOperation& c_element_op,
633  const CBlockClusterAdaptor& c_block_cluster_adaptor)
634  {
635  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
636  p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
637  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
638  p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
639  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
640  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
641 
642  const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
643 
644  // divide block work by [M, N]
645  const auto block_work_idx =
646  c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
647 
648  const index_t k_batch_id = block_work_idx[I0];
649 
650  if(!c_block_cluster_adaptor.ValidCTileIndex(
651  make_tuple(block_work_idx[I1], block_work_idx[I2]),
652  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
653  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
654  {
655  return;
656  }
657 
658  // HACK: this force m/n_block_data_idx_on_grid into SGPR
659  const index_t m_block_data_idx_on_grid =
660  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
661 
662  const index_t n_block_data_idx_on_grid =
663  __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
664 
665  // lds max alignment
666  constexpr auto max_lds_align = K1;
667 
668  // A matrix in LDS memory, dst of blockwise copy
669  constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
670 
671  constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1();
672  // B matrix in LDS memory, dst of blockwise copy
673  constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
674 
675  constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
676  // A matrix blockwise copy
677  auto a_blockwise_copy =
679  AElementwiseOperation,
681  InMemoryDataOperationEnum::Set,
683  ABlockTransferThreadClusterLengths_K0_M_K1,
684  ABlockTransferThreadClusterArrangeOrder,
685  FloatA,
687  decltype(a_b_k0_m_k1_grid_desc),
688  decltype(a_b_k0_m_k1_block_desc),
689  ABlockTransferSrcAccessOrder,
691  ABlockTransferSrcVectorDim,
692  3,
693  ABlockTransferSrcScalarPerVector,
694  ABlockTransferDstScalarPerVector_K1,
695  1,
696  1,
697  AThreadTransferSrcResetCoordinateAfterRun,
698  true>(
699  a_b_k0_m_k1_grid_desc,
700  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
701  a_element_op,
702  a_b_k0_m_k1_block_desc,
703  make_multi_index(0, 0, 0, 0),
705 
706  // B matrix blockwise copy
707  auto b_blockwise_copy =
709  BElementwiseOperation,
711  InMemoryDataOperationEnum::Set,
713  BBlockTransferThreadClusterLengths_K0_N_K1,
714  BBlockTransferThreadClusterArrangeOrder,
715  FloatB,
717  decltype(b_b_k0_n_k1_grid_desc),
718  decltype(b_b_k0_n_k1_block_desc),
719  BBlockTransferSrcAccessOrder,
721  BBlockTransferSrcVectorDim,
722  3,
723  BBlockTransferSrcScalarPerVector,
724  BBlockTransferDstScalarPerVector_K1,
725  1,
726  1,
727  BThreadTransferSrcResetCoordinateAfterRun,
728  true>(
729  b_b_k0_n_k1_grid_desc,
730  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
731  b_element_op,
732  b_b_k0_n_k1_block_desc,
733  make_multi_index(0, 0, 0, 0),
735 
736  // GEMM definition
737  // c_mtx += transpose(a_mtx) * b_mtx
738  // a_mtx[K0PerBlock, MPerBlock] is in LDS
739  // b_mtx[K0PerBlock, NPerBlock] is in LDS
740  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
741  // register
742  // sanity check
743  constexpr bool is_single_rate_mfma =
745  K1 <= 4) ||
748  K1 < 32))
749  ? true
750  : false;
751  constexpr auto is_scale_mfma = false;
752  constexpr index_t KPack = math::max(K1,
754  MPerXDL,
755  NPerXDL,
757  is_single_rate_mfma,
758  is_scale_mfma>::selected_mfma.k_per_blk);
759 
760  auto blockwise_gemm =
764  FloatAcc,
765  decltype(a_k0_m_k1_block_desc),
766  decltype(b_k0_n_k1_block_desc),
767  MPerXDL,
768  NPerXDL,
769  MRepeat,
770  NRepeat,
771  KPack>{};
772 
773  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
774 
775  // LDS allocation for A and B: be careful of alignment
776  constexpr auto a_block_space_size =
777  math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
778 
779  constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
780  constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
781 
782  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
783  static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
784 
785  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
786  static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
787  b_k0_n_k1_block_desc.GetElementSpaceSize());
788 
789  // gridwise GEMM pipeline
790  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
791 
792  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
793  a_b_k0_m_k1_block_desc,
794  a_blockwise_copy,
795  a_grid_buf,
796  a_block_buf,
797  a_block_slice_copy_step,
798  b_b_k0_n_k1_grid_desc,
799  b_b_k0_n_k1_block_desc,
800  b_blockwise_copy,
801  b_grid_buf,
802  b_block_buf,
803  b_block_slice_copy_step,
804  blockwise_gemm,
805  c_thread_buf,
806  K0BlockMainLoop);
807 
808  // output: register to global memory
809  {
810  constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
811  constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
812 
813  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
814  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
815 
816  constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
817  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
818 
819  constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
820  constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
821  constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
822  constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
823  constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
824  constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
825  constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
826  constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
827 
828  constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
829  GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
830 
831  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
832  static_cast<FloatC*>(p_shared),
833  c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
834 
835  static_assert(M1 == MWave, "");
836  static_assert(N1 == NWave, "");
837  static_assert(M2 * M3 * M4 == MPerXDL, "");
838  static_assert(N2 == NPerXDL, "");
839 
840  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
841  c_block_desc_mblock_mperblock_nblock_nperblock,
842  make_tuple(
843  make_freeze_transform(I0), // freeze mblock
844  make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
845  M1,
846  M2,
847  M3,
848  M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
849  make_freeze_transform(I0), // freeze nblock
850  make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
851  N1,
852  N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL
854  make_tuple(
856 
857  // calculate origin of thread output tensor on global memory
858  // blockwise GEMM c matrix starting index
859  const auto c_thread_mtx_on_block =
860  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
861 
862  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
863  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
864 
865  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
867  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
870 
871  const auto m_thread_data_on_block_idx =
872  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
873  make_multi_index(m_thread_data_on_block));
874 
875  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
880 
881  const auto n_thread_data_on_block_idx =
882  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
883  make_multi_index(n_thread_data_on_block));
884 
885  // VGPR to LDS
886  auto c_thread_copy_vgpr_to_lds =
888  FloatC,
889  decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
890  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
892  Sequence<CShuffleMRepeatPerShuffle,
893  CShuffleNRepeatPerShuffle,
894  I1,
895  I1,
896  M2,
897  I1,
898  M4,
899  I1>,
901  7,
902  1,
903  InMemoryDataOperationEnum::Set,
904  1,
905  true>{
906  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
908  0,
909  m_thread_data_on_block_idx[I1],
910  n_thread_data_on_block_idx[I1],
911  m_thread_data_on_block_idx[I2],
912  m_thread_data_on_block_idx[I3],
913  m_thread_data_on_block_idx[I4],
914  n_thread_data_on_block_idx[I2]),
916 
917  // LDS to global
918  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
919  ThisThreadBlock, // index_t BlockSize,
920  CElementwiseOperation, // ElementwiseOperation,
921  CGlobalMemoryDataOperation, // DstInMemOp,
922  Sequence<1,
923  CShuffleMRepeatPerShuffle * MWave * MPerXDL,
924  1,
925  CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
926  CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
927  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
928  FloatC, // typename SrcData,
929  FloatC, // typename DstData,
930  decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
931  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
932  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
933  3, // index_t VectorDim,
934  CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
935  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
936  false> // bool ThreadTransferDstResetCoordinateAfterRun
937  {c_block_desc_mblock_mperblock_nblock_nperblock,
938  make_multi_index(0, 0, 0, 0),
939  c_grid_desc_mblock_mperblock_nblock_nperblock,
940  make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
941  c_element_op};
942 
943  constexpr auto mxdlperwave_forward_step =
944  make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
945  constexpr auto nxdlperwave_forward_step =
946  make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
947  constexpr auto nxdlperwave_backward_step =
948  make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);
949 
950  static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
951  constexpr auto mxdlperwave = mxdlperwave_iter;
952 
953  static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
954  constexpr bool nxdlperwave_forward_sweep =
955  (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
956 
957  constexpr index_t nxdlperwave_value =
958  nxdlperwave_forward_sweep
959  ? nxdlperwave_iter
960  : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
961 
962  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
963 
964  // make sure it's safe to do ds_write
965  block_sync_lds();
966 
967  // VGPR to LDS
968  c_thread_copy_vgpr_to_lds.Run(
969  c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
970  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
971  c_thread_buf,
972  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
973  c_block_buf);
974 
975  // make sure it's safe to do ds_read
976  block_sync_lds();
977 
978  // LDS to global
979  c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
980  c_block_buf,
981  c_grid_desc_mblock_mperblock_nblock_nperblock,
982  c_grid_buf);
983 
984  // move on nxdlperwave dimension
985  if constexpr(nxdlperwave_forward_sweep &&
986  (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
987  {
988  c_block_copy_lds_to_global.MoveDstSliceWindow(
989  c_grid_desc_mblock_mperblock_nblock_nperblock,
990  nxdlperwave_forward_step);
991  }
992  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
993  {
994  c_block_copy_lds_to_global.MoveDstSliceWindow(
995  c_grid_desc_mblock_mperblock_nblock_nperblock,
996  nxdlperwave_backward_step);
997  }
998  });
999 
1000  // move on mxdlperwave dimension
1001  if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1002  {
1003  c_block_copy_lds_to_global.MoveDstSliceWindow(
1004  c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
1005  }
1006  });
1007  }
1008  }
1009 }; // namespace ck
1010 
1011 } // namespace ck
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const CBlockClusterAdaptor c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:157
InMemoryDataOperationEnum
Definition: ck.hpp:276
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
ushort bhalf_t
Definition: data_type.hpp:29
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:58
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:298
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition: statically_indexed_array_multi_index.hpp:147
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
__host__ constexpr __device__ auto make_merge_transform_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:136
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: array.hpp:14
Definition: block_to_ctile_map.hpp:719
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:80
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:250
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:266
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:395
ComputeTypeB FloatBAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:280
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:263
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:497
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:618
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:293
ComputeTypeA FloatAAdjusted
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:279
__host__ static constexpr __device__ auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:339
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:605
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:569
__host__ static constexpr __device__ auto MakeCBlockClusterAdaptor(const CMNGridDesc &c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:597
static __device__ void Run(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const CBlockClusterAdaptor &c_block_cluster_adaptor)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:622
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_B_K0_M_K1 &a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 &b_b_k0_n_k1_grid_desc, const CMNGridDesc &c_m_n_grid_desc, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:525
__host__ static constexpr __device__ auto MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc &c_m_n_grid_desc)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:580
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)) CBlockClusterAdaptor
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:619
__host__ static constexpr __device__ auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:441
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:25
__host__ constexpr __device__ Merge_v4_no_carry(const LowLengths &low_lengths)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:43
LowLengthsScan low_lengths_scan_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:38
__host__ constexpr __device__ Merge_v4_no_carry()=default
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:35
__host__ static constexpr __device__ bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:116
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_up_diff, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:81
static constexpr index_t NDimLow
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:26
__host__ static constexpr __device__ index_t GetNumOfLowerDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:52
__host__ constexpr __device__ const auto & GetUpperLengths() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:56
__host__ constexpr __device__ void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:59
__host__ static constexpr __device__ bool IsKnownAtCompileTime()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:107
__host__ static constexpr __device__ bool IsLinearTransform()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:100
UpLengths up_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:39
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:32
__host__ static constexpr __device__ index_t GetNumOfUpperDimension()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:54
__host__ static constexpr __device__ bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:102
LowLengths low_lengths_
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:37
__host__ __device__ void Print() const
Definition: gridwise_gemm_xdlops_bwd_weight.hpp:121
Definition: xdlops_gemm.hpp:1126
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: is_known_at_compile_time.hpp:14
Definition: type.hpp:177
Definition: math.hpp:34
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:334