1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
|
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Copy S into D through a local fragment, applying a rotary transform with *interleaved*
// pairing to every slice whose column coordinate is below `rotary_dim`.
//
// Within a slice S(_, m, k), elements (2i, 2i+1) form one pair that is rotated by the i-th
// cos/sin value of the matching Cos/Sin slice:
//     out[2i]   = x[2i] * cos[i] - x[2i+1] * sin[i]
//     out[2i+1] = x[2i] * sin[i] + x[2i+1] * cos[i]
// The arithmetic is carried out in float and converted back to S's element type.
//
// Template parameters:
//   Is_even_K    - when true, the per-slice `column < dim` check is bypassed.
//   Clear_OOB_K  - when true, D slices rejected by the `column < dim` check are cleared;
//                  when false they are left untouched.
// Arguments:
//   S, D         - rank-3 source / destination with identical mode sizes.
//   Cos, Sin     - share one tensor type; mode 0 holds half as many elements as S's mode 0.
//   identity_MN  - indexed as (0, m, 0) and (0, 0, k); get<0> is used as the row coordinate
//                  and get<1> as the column coordinate (presumably an identity/coordinate
//                  tensor partitioned like S - confirm against the caller).
//   max_MN/min_MN- rows with coordinate outside [min_MN, max_MN) are skipped entirely.
//   dim          - column bound used when Is_even_K is false.
//   rotary_dim   - slices with column coordinate >= rotary_dim are copied without rotation.
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                                        Tensor<Engine1, Layout1> &D,
                                                        Tensor<Engine2, Layout2> const &Cos,
                                                        Tensor<Engine2, Layout2> const &Sin,
                                                        Tensor<Engine3, Layout3> const &identity_MN,
                                                        const int max_MN, const int min_MN,
                                                        const int dim, const int rotary_dim) {
    // Source and destination must agree in rank and in every mode.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // Cos / Sin must tile the same (m, k) grid as S ...
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    // ... and carry one cos/sin value per pair of S elements.
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                 // MMA
    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32

    using Element = typename Engine0::value_type;

    Tensor frag_cos = make_fragment_like(Cos);
    Tensor frag_sin = make_fragment_like(Sin);
    Tensor frag_x = make_fragment_like(S);

    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        const auto row = get<0>(identity_MN(0, m, 0));
        // Rows outside [min_MN, max_MN): nothing is loaded and D is not written.
        if (row < min_MN || row >= max_MN) { continue; }
        #pragma unroll
        for (int k = 0; k < size<2>(S); ++k) {
            const auto col = get<1>(identity_MN(0, 0, k));
            if (!(Is_even_K || col < dim)) {
                // Slice lies past `dim`: optionally zero the destination, never read S.
                if (Clear_OOB_K) { cute::clear(D(_, m, k)); }
                continue;
            }
            cute::copy(S(_, m, k), frag_x(_, m, k));
            if (col < rotary_dim) {
                cute::copy(Cos(_, m, k), frag_cos(_, m, k));
                cute::copy(Sin(_, m, k), frag_sin(_, m, k));
                Tensor x_f32 = convert_type<float>(frag_x(_, m, k));
                Tensor cos_f32 = convert_type<float>(frag_cos(_, m, k));
                Tensor sin_f32 = convert_type<float>(frag_sin(_, m, k));
                #pragma unroll
                for (int i = 0; i < size<0>(frag_x) / 2; ++i) {
                    // Read both members of the pair before overwriting either one.
                    const float even = x_f32(2 * i);
                    const float odd = x_f32(2 * i + 1);
                    x_f32(2 * i) = even * cos_f32(i) - odd * sin_f32(i);
                    x_f32(2 * i + 1) = even * sin_f32(i) + odd * cos_f32(i);
                }
                // The rotated values are first copied into a freshly made fragment and only
                // that copy is handed to convert_type; converting x_f32 directly was found
                // not to work (reason not established here).
                Tensor x_f32_stable = make_fragment_like(x_f32);
                cute::copy(x_f32, x_f32_stable);
                Tensor x_rotated = convert_type<Element>(x_f32_stable);
                cute::copy(x_rotated, frag_x(_, m, k));
            }
            cute::copy(frag_x(_, m, k), D(_, m, k));
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Copy S into D through a local fragment, applying a rotary transform with *contiguous*
// (half-split) pairing to every slice whose column coordinate is below `rotary_dim`.
//
// A slice in the first half of the rotary range (column < rotary_dim / 2) is paired with the
// data located rotary_dim / 2 elements after it in S; a slice in the second half is paired
// with the data rotary_dim / 2 elements before it. With x the slice and y its partner:
//     first half :  out = x * cos - y * sin
//     second half:  out = x * cos + y * sin
// For second-half slices the Cos/Sin pointers are moved back by rotary_dim / 2 before loading.
// The arithmetic is carried out in float and converted back to S's element type.
//
// NOTE(review): the partner / cos / sin slices are reached by offsetting `data()` by
// +-rotary_dim / 2 elements while reusing the slice layout, which presumably relies on
// consecutive columns being adjacent in memory - confirm against the caller.
//
// Template parameters:
//   Is_even_K    - when true, the per-slice `column < dim` check is bypassed.
//   Clear_OOB_K  - when true, D slices rejected by the `column < dim` check are cleared;
//                  when false they are left untouched.
// Arguments:
//   S, D         - rank-3 source / destination with identical mode sizes.
//   Cos, Sin     - share one tensor type; all three mode sizes match S.
//   identity_MN  - indexed as (0, m, 0) and (0, 0, k); get<0> is used as the row coordinate
//                  and get<1> as the column coordinate (presumably an identity/coordinate
//                  tensor partitioned like S - confirm against the caller).
//   max_MN/min_MN- rows with coordinate outside [min_MN, max_MN) are skipped entirely.
//   dim          - column bound used when Is_even_K is false.
//   rotary_dim   - slices with column coordinate >= rotary_dim are copied without rotation.
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                                       Tensor<Engine1, Layout1> &D,
                                                       Tensor<Engine2, Layout2> const &Cos,
                                                       Tensor<Engine2, Layout2> const &Sin,
                                                       Tensor<Engine3, Layout3> const &identity_MN,
                                                       const int max_MN, const int min_MN,
                                                       const int dim, const int rotary_dim) {
    // Source and destination must agree in rank and in every mode.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // Cos / Sin must match S element-for-element.
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                   // MMA
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32

    using Element = typename Engine0::value_type;

    Tensor frag_cos = make_fragment_like(Cos);
    Tensor frag_sin = make_fragment_like(Sin);
    Tensor frag_x = make_fragment_like(S);
    // Scratch fragment for the partner slice; shaped like a single (_, m, k) slice.
    Tensor frag_partner = make_fragment_like(frag_x(_, 0, 0));

    // Distance (in elements) between a slice and its rotation partner.
    const int half_rotary = rotary_dim / 2;

    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        const auto row = get<0>(identity_MN(0, m, 0));
        // Rows outside [min_MN, max_MN): nothing is loaded and D is not written.
        if (row < min_MN || row >= max_MN) { continue; }
        #pragma unroll
        for (int k = 0; k < size<2>(S); ++k) {
            const auto col = get<1>(identity_MN(0, 0, k));
            if (!(Is_even_K || col < dim)) {
                // Slice lies past `dim`: optionally zero the destination, never read S.
                if (Clear_OOB_K) { cute::clear(D(_, m, k)); }
                continue;
            }
            cute::copy(S(_, m, k), frag_x(_, m, k));
            if (col < rotary_dim) {
                const bool is_left = col < half_rotary;
                // Partner data sits half_rotary elements ahead (first half) or behind (second half).
                const int partner_shift = is_left ? half_rotary : -half_rotary;
                // Second-half slices reuse the cos/sin values of the first half.
                const int trig_shift = is_left ? 0 : -half_rotary;

                Tensor gmem_partner = make_tensor(S(_, m, k).data() + partner_shift, S(_, m, k).layout());
                cute::copy(gmem_partner, frag_partner);
                Tensor gmem_cos = make_tensor(Cos(_, m, k).data() + trig_shift, Cos(_, m, k).layout());
                Tensor gmem_sin = make_tensor(Sin(_, m, k).data() + trig_shift, Sin(_, m, k).layout());
                cute::copy(gmem_cos, frag_cos(_, m, k));
                cute::copy(gmem_sin, frag_sin(_, m, k));

                Tensor x_f32 = convert_type<float>(frag_x(_, m, k));
                Tensor partner_f32 = convert_type<float>(frag_partner);
                Tensor cos_f32 = convert_type<float>(frag_cos(_, m, k));
                Tensor sin_f32 = convert_type<float>(frag_sin(_, m, k));
                #pragma unroll
                for (int i = 0; i < size<0>(frag_x); ++i) {
                    // The sine term flips sign between the two halves.
                    const float signed_sin = is_left ? -sin_f32(i) : sin_f32(i);
                    x_f32(i) = x_f32(i) * cos_f32(i) + partner_f32(i) * signed_sin;
                }
                // The rotated values are first copied into a freshly made fragment and only
                // that copy is handed to convert_type; converting x_f32 directly was found
                // not to work (reason not established here).
                Tensor x_f32_stable = make_fragment_like(x_f32);
                cute::copy(x_f32, x_f32_stable);
                Tensor x_rotated = convert_type<Element>(x_f32_stable);
                cute::copy(x_rotated, frag_x(_, m, k));
            }
            cute::copy(frag_x(_, m, k), D(_, m, k));
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
|