Commit 304584b

resolve comments; add packWeight

1 parent: abdf06d

File tree: 4 files changed (+145, -20)


modules/dnn/include/opencv2/dnn/all_layers.hpp

Lines changed: 0 additions & 4 deletions
@@ -1168,10 +1168,6 @@ CV__DNN_INLINE_NS_BEGIN
 
     class CV_EXPORTS AttentionLayer : public Layer {
     public:
-        int num_heads;
-        std::vector<int> qkv_hidden_sizes;
-        float scale;
-
        static Ptr<AttentionLayer> create(const LayerParams &params);
    };

modules/dnn/src/layers/attention_layer.cpp

Lines changed: 81 additions & 16 deletions
@@ -3,9 +3,26 @@
 // of this distribution and at http://opencv.org/license.html.
 
 #include "../precomp.hpp"
+#include "cpu_kernels/fast_gemm.hpp"
+
+#include <opencv2/dnn/shape_utils.hpp>
 
 namespace cv { namespace dnn {
 
+static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size,
+                       const float *weight_data, size_t hidden_size, std::vector<float> &packed_weight, const FastGemmOpt &opt) {
+    // num_heads * pack(head_size, input_hidden_size)
+    size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt);
+    size_t packed_weight_size = num_heads * pack_size;
+    packed_weight.resize(packed_weight_size, 0.f);
+    auto *packed_weight_data = packed_weight.data();
+    for (size_t i = 0; i < num_heads; i++) {
+        fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt);
+        packed_weight_data += pack_size;
+        weight_data += head_size;
+    }
+}
+
 // Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention
 class AttentionLayerImpl CV_FINAL : public AttentionLayer {
  public:
@@ -19,14 +36,19 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
         auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
         CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must and only have three elements");
         qkv_hidden_sizes.resize(3);
-        qkv_hidden_sizes[0] = param_qkv_hidden_sizes.get<int>(0);
-        qkv_hidden_sizes[1] = param_qkv_hidden_sizes.get<int>(1);
-        qkv_hidden_sizes[2] = param_qkv_hidden_sizes.get<int>(2);
+        qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
+        qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
+        qkv_hidden_sizes[2] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(2));
+
+        hidden_size = qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2];
 
-        qk_head_size = static_cast<int>(qkv_hidden_sizes[0] / num_heads);
-        v_head_size = static_cast<int>(qkv_hidden_sizes[2] / num_heads);
+        qkv_head_sizes.resize(3);
+        std::transform(qkv_hidden_sizes.begin(), qkv_hidden_sizes.end(), qkv_head_sizes.begin(),
+                       [this] (const size_t w) { return static_cast<size_t>(w / num_heads); });
 
-        scale = params.get<float>("scale", sqrt(1.f / qk_head_size));
+        scale = params.get<float>("scale", sqrt(1.f / qkv_head_sizes[0]));
+
+        is_prepacked = false;
     }
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE {
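(Note: per the com.microsoft.Attention operator spec referenced at the top of this file, the default when no "scale" attribute is given is 1/sqrt(qk_head_size), which is what sqrt(1.f / qkv_head_sizes[0]) computes; e.g. qk_head_size = 64 yields a default scale of 0.125.)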
@@ -37,19 +59,34 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
                          const int requiredOutputs,
                          std::vector<MatShape> &outputs,
                          std::vector<MatShape> &internals) const CV_OVERRIDE {
-        const auto &input = inputs[0];
-        const auto &weight = inputs[1];
-        const auto &bias = inputs[2];
-        int dim_bias = std::accumulate(bias.begin(), bias.end(), 1, std::multiplies<int>());
+        CV_CheckEQ(inputs.size(), static_cast<size_t>(3), "DNN/Attention: three inputs are required");
+        const auto &input_shape = inputs[0];
+        const auto &weight_shape = inputs[1];
+        const auto &bias_shape = inputs[2];
+        size_t dim_bias = static_cast<size_t>(std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int>()));
+
+        CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
+        CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");
 
-        CV_CheckEQ(input.back(), weight[weight.size() - 2], "DNN/Attention: invalid input shape");
-        CV_CheckEQ(weight.back(), qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2], "DNN/Attention: invalid weight shape");
-        CV_CheckEQ(dim_bias, qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2], "DNN/Attention: invalid bias shape");
+        CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
+        CV_CheckEQ(static_cast<size_t>(weight_shape[1]), hidden_size, "DNN/Attention: invalid weight shape");
+        CV_CheckEQ(dim_bias, hidden_size, "DNN/Attention: invalid bias shape");
 
         outputs.assign(1, inputs[0]);
         return false;
     }
 
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
+        opt.init();
+
+        std::vector<Mat> inputs;
+        inputs_arr.getMatVector(inputs);
+        const auto input_shape = shape(inputs[0]);
+        batch_size = static_cast<size_t>(input_shape[0]);
+        seq_len = static_cast<size_t>(input_shape[1]);
+        input_hidden_size = static_cast<size_t>(input_shape[2]);
+    }
+
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
@@ -63,13 +100,41 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
         std::vector<Mat> inputs, outputs;
         inputs_arr.getMatVector(inputs);
         outputs_arr.getMatVector(outputs);
+        const auto &input = inputs[0];
+        auto &output = outputs[0];
+
+        if (!is_prepacked) {
+            // prepack
+            const auto &weight = inputs[1];
+            const auto *weight_data = weight.ptr<const float>();
+            packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_q, opt);
+            packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_k, opt);
+            packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_v, opt);
+
+            is_prepacked = true;
+        }
 
-        // TODO: impl
+        input.copyTo(output);
     }
 
  private:
-    int qk_head_size;
-    int v_head_size;
+    size_t num_heads;
+    std::vector<size_t> qkv_hidden_sizes;  // order: {qk_hidden_size, qk_hidden_size, v_hidden_size}
+    float scale;
+
+    std::vector<size_t> qkv_head_sizes;  // order: {qk_head_size, qk_head_size, v_head_size}
+
+    size_t batch_size;
+    size_t seq_len;
+    size_t input_hidden_size;
+    size_t hidden_size;
+
+    bool is_prepacked;
+    std::vector<float> packed_q;
+    std::vector<float> packed_k;
+    std::vector<float> packed_v;
+
+    FastGemmOpt opt;
 };
 
 Ptr<AttentionLayer> AttentionLayer::create(const LayerParams &params) {
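For orientation (not part of the commit): packWeight treats the QKV weight as a row-major [input_hidden_size x hidden_size] matrix and packs one [input_hidden_size x head_size] column block per head, advancing weight_data by head_size columns each iteration. A minimal sketch of the addressing it relies on, where copyHeadBlock is a hypothetical reference helper:

#include <cstddef>
#include <vector>

// Hypothetical helper, for illustration only: materialize head h's column
// block exactly as packWeight hands it to fastGemmPackB (N = head_size,
// K = input_hidden_size, ldb = hidden_size).
static void copyHeadBlock(const float *weight, size_t input_hidden_size,
                          size_t hidden_size, size_t head_size, size_t h,
                          std::vector<float> &block) {
    block.resize(input_hidden_size * head_size);
    const float *src = weight + h * head_size;  // packWeight reaches this offset by advancing weight_data
    for (size_t r = 0; r < input_hidden_size; r++)
        for (size_t c = 0; c < head_size; c++)
            block[r * head_size + c] = src[r * hidden_size + c];
}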

modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp

Lines changed: 61 additions & 0 deletions
@@ -20,6 +20,32 @@
 
 namespace cv { namespace dnn {
 
+size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt) {
+#if CV_TRY_NEON
+    if (opt.use_neon) {
+        return static_cast<size_t>(opt_NEON::fastGemmPackBSize(N, K));
+    } else
+#endif
+#if CV_TRY_AVX2
+    if (opt.use_avx2) {
+        return static_cast<size_t>(opt_AVX2::fastGemmPackBSize(N, K));
+    } else
+#endif
+#if CV_TRY_AVX
+    if (opt.use_avx) {
+        return static_cast<size_t>(opt_AVX::fastGemmPackBSize(N, K));
+    } else
+#endif
+#if CV_TRY_LASX
+    if (opt.use_lasx) {
+        return static_cast<size_t>(opt_LASX::fastGemmPackBSize(N, K));
+    } else
+#endif
+    {
+        return static_cast<size_t>(cpu_baseline::fastGemmPackBSize(N, K));
+    }
+}
+
 void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt) {
     CV_CheckEQ(B.dims, 2, "fastGemmPackB: input mat should be two-dimensional");
     CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now");
@@ -66,6 +92,41 @@ void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt) {
     }
 }
 
+void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt) {
+    size_t ldb0 = ldb, ldb1 = 1;
+    if (trans) {
+        std::swap(K, N);
+        std::swap(ldb0, ldb1);
+    }
+
+    const auto &b = (const char *)B;
+    auto *packed_b = (char *)packed_B;
+
+#if CV_TRY_NEON
+    if (opt.use_neon) {
+        opt_NEON::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
+    } else
+#endif
+#if CV_TRY_AVX2
+    if (opt.use_avx2) {
+        opt_AVX2::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
+    } else
+#endif
+#if CV_TRY_AVX
+    if (opt.use_avx) {
+        opt_AVX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
+    } else
+#endif
+#if CV_TRY_LASX
+    if (opt.use_lasx) {
+        opt_LASX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
+    } else
+#endif
+    {
+        cpu_baseline::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
+    }
+}
+
 static void fast_gemm_thin(float alpha, float beta, int M, int N, int K,
                            const char *a_, int lda0, int lda1,
                            const char *b_, int ldb,
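A minimal usage sketch of the two new entry points together (illustrative, not from the commit; packExample is a made-up caller and assumes a row-major K x N matrix with ldb == N):

#include <vector>

// Illustrative only: query the packed size, then pack B for later fastGemm calls.
static void packExample(const float *B, size_t N, size_t K, const FastGemmOpt &opt) {
    std::vector<float> packed(fastGemmPackBSize(N, K, opt), 0.f);
    fastGemmPackB(/*trans=*/false, N, K, B, /*ldb=*/N, packed.data(), opt);
    // With trans=true the same call reads B as its transpose: the overload
    // swaps (K, N) and the strides (ldb0, ldb1) instead of copying data.
}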

modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp

Lines changed: 3 additions & 0 deletions
@@ -42,7 +42,10 @@ struct FastGemmOpt {
     }
 };
 
+size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt);
+
 void fastGemmPackB(const Mat &m, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt);
+void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt);
 
 void fastGemm(bool trans_a, int M, int N, int K,
               float alpha, const float *A, int lda,
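To make the sizing concrete, a hypothetical BERT-base-like configuration (all numbers illustrative, not taken from the commit):

// Hypothetical Attention node: 12 heads, 768-dim input, equal Q/K/V widths.
const size_t num_heads = 12, input_hidden_size = 768;
const size_t qkv_hidden_sizes[3] = {768, 768, 768};  // hidden_size = 2304
const size_t qk_head_size = 768 / num_heads;         // 64 columns per head
// packed_q then holds num_heads * fastGemmPackBSize(qk_head_size, input_hidden_size, opt)
// floats, i.e. one packed [768 x 64] block per head.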

0 commit comments

