include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File

include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File#

Composable Kernel: include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File
rmsnorm2d_fwd_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 enum class Rmsnorm2dFusedAddEnum
11 {
12  NO_ADD = 0,
13  // fused add before RMSNorm; the add result is also stored to global memory
14  PRE_ADD_STORE = 1,
15  // fused add before RMSNorm, but the add result is not stored
16  PRE_ADD = 2,
17 };
18 
19 // clang-format off
20 template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
21 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
22 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
23 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
24 // clang-format on
25 
26 enum class Rmsnorm2dFusedQuantEnum
27 {
28  NO_SWEEP = 0,
29  SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant; needs an input x-scale and stores a y-scale
30  DYNAMIC_QUANT = 2, // rowwise quant; stores a y-scale
31 };
32 
33 // clang-format off
34 template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
35 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
36 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
37 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
38 // clang-format on
39 
40 template <bool kPadN_,
41  bool kSaveInvRms_,
42  bool kTwoPass_,
43  Rmsnorm2dFusedAddEnum kFusedAdd_,
44  Rmsnorm2dFusedQuantEnum kFusedQuant_>
45 struct Rmsnorm2dFwdTraits
46 {
47  static constexpr bool kPadN = kPadN_;
48  static constexpr bool kSaveInvRms = kSaveInvRms_;
49  static constexpr bool kTwoPass = kTwoPass_;
50  static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
51  static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
52 };
53 
54 } // namespace ck_tile
ck_tile
Definition: cluster_descriptor.hpp:13
Rmsnorm2dFusedQuantEnum
Definition: rmsnorm2d_fwd_traits.hpp:27
Rmsnorm2dFusedAddEnum
Definition: rmsnorm2d_fwd_traits.hpp:11
Rmsnorm2dFusedAddEnumName
Definition: rmsnorm2d_fwd_traits.hpp:20
Rmsnorm2dFusedQuantEnumName
Definition: rmsnorm2d_fwd_traits.hpp:34
Rmsnorm2dFwdTraits
Definition: rmsnorm2d_fwd_traits.hpp:46
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_traits.hpp:48
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_traits.hpp:49
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_traits.hpp:47
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant
Definition: rmsnorm2d_fwd_traits.hpp:51
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd
Definition: rmsnorm2d_fwd_traits.hpp:50