/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp Source File
rmsnorm2d_fwd_traits.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
// Enumerator list of Rmsnorm2dFusedAddEnum (the enum these values are qualified with in the
// Rmsnorm2dFusedAddEnumName specialisations further down).
// NOTE(review): the enum's declaration line (source line 10) is not present in this listing.
11 {
12  NO_ADD = 0, // presumably: no fused add is performed (inferred from the name; confirm)
13  // fused add before RMSNorm; the add result is also stored to global
14  PRE_ADD_STORE = 1,
15  // fused add before RMSNorm; the add result is not stored
16  PRE_ADD = 2,
17 };
18 
19 // clang-format off
// Compile-time lookup from a Rmsnorm2dFusedAddEnum value to a short name string, exposed as the
// static member `name`: NO_ADD -> "no", PRE_ADD_STORE -> "pras", PRE_ADD -> "pra".
// The primary template is only declared, never defined, so using a value without a
// specialisation below fails to compile.
20 template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
21 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
22 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
23 template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
24 // clang-format on
25 
// Enumerator list of Rmsnorm2dFusedQuantEnum (the enum these values are qualified with in the
// Rmsnorm2dFusedQuantEnumName specialisations further down).
// NOTE(review): the enum's declaration line (source line 26) is not present in this listing.
27 {
28  NO_SWEEP = 0, // mapped to the name "no" by Rmsnorm2dFusedQuantEnumName
29  SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant; needs an input x-scale and stores a y-scale
30  DYNAMIC_QUANT = 2, // rowwise quant; stores out a y-scale
31 };
32 
33 // clang-format off
// Compile-time lookup from a Rmsnorm2dFusedQuantEnum value to a short name string, exposed as the
// static member `name`: NO_SWEEP -> "no", DYNAMIC_QUANT -> "dqt", SMOOTH_DYNAMIC_QUANT -> "smdqt".
// The primary template is only declared, never defined, so using a value without a
// specialisation below fails to compile.
34 template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
35 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
36 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
37 template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
38 // clang-format on
39 
// Enumerator list of Rmsnorm2dSensitiveEnum (the enum these values are qualified with in the
// Rmsnorm2dSensitiveEnumName specialisations further down).
// NOTE(review): source lines 40 and 42 are not present in this listing. The enum's declaration
// line is missing, and so is the NO_SPECIFIC_MODEL enumerator that the name-trait
// specialisations reference; its numeric value cannot be read from here.
41 {
43  // T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
44  // architecture designed for a variety of NLP tasks. This option mimics T5's approach to
45  // RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
46  T5_MODEL_LIKE = 1,
47 };
48 
49 // clang-format off
// Compile-time lookup from a Rmsnorm2dSensitiveEnum value to a short name string, exposed as the
// static member `name`: NO_SPECIFIC_MODEL -> "nsm", T5_MODEL_LIKE -> "t5ml".
// The primary template is only declared, never defined, so using a value without a
// specialisation below fails to compile.
50 template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
51 template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
52 template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
53 // clang-format on
54 
// Compile-time traits bundle: four bool flags plus one value from each of the three enums above.
// Each template parameter is re-exported unchanged as a static constexpr member of the same name
// without the trailing underscore, so other code can read the configuration from the type.
// The struct holds no data members and no functions; what each flag does is not defined in this
// file. Presumably the code that consumes the traits decides that; confirm there.
// NOTE(review): the `struct` declaration line (source line 62) is not present in this listing.
55 template <bool kPadN_,
56  bool kSaveInvRms_,
57  bool kSaveUnquant_,
58  bool kTwoPass_,
59  Rmsnorm2dFusedAddEnum kFusedAdd_,
60  Rmsnorm2dFusedQuantEnum kFusedQuant_,
61  Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
63 {
64  static constexpr bool kPadN = kPadN_;
65  static constexpr bool kSaveInvRms = kSaveInvRms_;
66  static constexpr bool kSaveUnquant = kSaveUnquant_;
67  static constexpr bool kTwoPass = kTwoPass_;
68  static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
69  static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
70  static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
71 };
72 
73 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
Rmsnorm2dSensitiveEnum
Definition: rmsnorm2d_fwd_traits.hpp:41
Rmsnorm2dFusedQuantEnum
Definition: rmsnorm2d_fwd_traits.hpp:27
Rmsnorm2dFusedAddEnum
Definition: rmsnorm2d_fwd_traits.hpp:11
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_traits.hpp:63
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_traits.hpp:66
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd
Definition: rmsnorm2d_fwd_traits.hpp:68
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm
Definition: rmsnorm2d_fwd_traits.hpp:70
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant
Definition: rmsnorm2d_fwd_traits.hpp:69
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_traits.hpp:67
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_traits.hpp:64
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_traits.hpp:65
Definition: rmsnorm2d_fwd_traits.hpp:50