 // of this distribution and at http://opencv.org/license.html.
 
 #include "../precomp.hpp"
+#include "cpu_kernels/fast_gemm.hpp"
+
+#include <opencv2/dnn/shape_utils.hpp>
+
+#include <algorithm>  // std::transform
+#include <numeric>    // std::accumulate
 
 namespace cv { namespace dnn {
 
+static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size,
+                       const float *weight_data, size_t hidden_size, std::vector<float> &packed_weight, const FastGemmOpt &opt) {
+    // num_heads * pack(head_size, input_hidden_size)
+    size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt);
+    size_t packed_weight_size = num_heads * pack_size;
+    packed_weight.resize(packed_weight_size, 0.f);
+    auto *packed_weight_data = packed_weight.data();
+    for (size_t i = 0; i < num_heads; i++) {
+        fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt);
+        packed_weight_data += pack_size;
+        weight_data += head_size;
+    }
+}
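+
+// Illustration of the packed layout produced above (hypothetical sizes, not taken from any real model):
+// with num_heads = 12, head_size = 64, input_hidden_size = 768 and hidden_size = 2304, head i packs
+// the 768 x 64 sub-block starting at weight_data + i * 64 (row stride hidden_size = 2304) into one
+// contiguous run of pack_size floats, so packed_weight ends up holding 12 such runs back to back,
+// one per head.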
+
 // Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention
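+// Shape summary, per the contrib-op spec linked above (only the first three inputs -- input, weights,
+// bias -- are handled by this layer):
+//   input:  [batch_size, seq_len, input_hidden_size]
+//   weight: [input_hidden_size, hidden_size], with hidden_size = qk_hidden_size * 2 + v_hidden_size
+//   bias:   [hidden_size]
+// These are the shapes checked in getMemoryShapes() below.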
 class AttentionLayerImpl CV_FINAL : public AttentionLayer {
 public:
@@ -19,14 +36,19 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
         auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
         CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must have exactly three elements");
         qkv_hidden_sizes.resize(3);
-        qkv_hidden_sizes[0] = param_qkv_hidden_sizes.get<int>(0);
-        qkv_hidden_sizes[1] = param_qkv_hidden_sizes.get<int>(1);
-        qkv_hidden_sizes[2] = param_qkv_hidden_sizes.get<int>(2);
+        qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
+        qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
+        qkv_hidden_sizes[2] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(2));
+
+        hidden_size = qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2];
 
-        qk_head_size = static_cast<int>(qkv_hidden_sizes[0] / num_heads);
-        v_head_size = static_cast<int>(qkv_hidden_sizes[2] / num_heads);
+        qkv_head_sizes.reserve(3); // reserve, not resize: std::back_inserter appends below
+        std::transform(qkv_hidden_sizes.begin(), qkv_hidden_sizes.end(), std::back_inserter(qkv_head_sizes),
+                       [this] (const size_t w) { return static_cast<size_t>(w / num_heads); });
 
-        scale = params.get<float>("scale", sqrt(1.f / qk_head_size));
+        // Default scale is sqrt(1 / qk_head_size), i.e. the 1/sqrt(d_k) factor of scaled dot-product attention.
+        scale = params.get<float>("scale", sqrt(1.f / qkv_head_sizes[0]));
+
+        is_prepacked = false;
     }
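+    // Worked example for the sizes computed above (hypothetical attribute values):
+    // qkv_hidden_sizes = {768, 768, 768} and num_heads = 12 give qkv_head_sizes = {64, 64, 64},
+    // hidden_size = 2304 and a default scale of sqrt(1.f / 64) = 0.125, i.e. 1/sqrt(d_k).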
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE {
@@ -37,19 +59,34 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
                                  const int requiredOutputs,
                                  std::vector<MatShape> &outputs,
                                  std::vector<MatShape> &internals) const CV_OVERRIDE {
-        const auto &input = inputs[0];
-        const auto &weight = inputs[1];
-        const auto &bias = inputs[2];
-        int dim_bias = std::accumulate(bias.begin(), bias.end(), 1, std::multiplies<int>());
+        CV_CheckEQ(inputs.size(), static_cast<size_t>(3), "DNN/Attention: three inputs are required");
+        const auto &input_shape = inputs[0];
+        const auto &weight_shape = inputs[1];
+        const auto &bias_shape = inputs[2];
+        size_t dim_bias = static_cast<size_t>(std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int>()));
+
+        CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
+        CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");
 
-        CV_CheckEQ(input.back(), weight[weight.size() - 2], "DNN/Attention: invalid input shape");
-        CV_CheckEQ(weight.back(), qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2], "DNN/Attention: invalid weight shape");
-        CV_CheckEQ(dim_bias, qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2], "DNN/Attention: invalid bias shape");
+        CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
+        CV_CheckEQ(static_cast<size_t>(weight_shape[1]), hidden_size, "DNN/Attention: invalid weight shape");
+        CV_CheckEQ(dim_bias, hidden_size, "DNN/Attention: invalid bias shape");
 
         outputs.assign(1, inputs[0]);
         return false;
     }
 
+    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
+        opt.init();
+
+        std::vector<Mat> inputs;
+        inputs_arr.getMatVector(inputs);
+        const auto input_shape = shape(inputs[0]);
+        batch_size = static_cast<size_t>(input_shape[0]);
+        seq_len = static_cast<size_t>(input_shape[1]);
+        input_hidden_size = static_cast<size_t>(input_shape[2]);
+    }
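+    // Note: finalize() runs once input shapes are known and before the first forward() call,
+    // so the sizes cached above are already set when forward() packs the weight slices.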
+
     void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
@@ -63,13 +100,41 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
         std::vector<Mat> inputs, outputs;
         inputs_arr.getMatVector(inputs);
         outputs_arr.getMatVector(outputs);
+        const auto &input = inputs[0];
+        auto &output = outputs[0];
+
+        if (!is_prepacked) {
+            // Pack the Q/K/V weight slices once on the first forward call; later calls reuse the packed buffers.
+            const auto &weight = inputs[1];
+            const auto *weight_data = weight.ptr<const float>();
+            packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_q, opt);
+            packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_k, opt);
+            packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_v, opt);
+
+            is_prepacked = true;
+        }
 
-        // TODO: impl
+        input.copyTo(output);
     }
 
 private:
-    int qk_head_size;
-    int v_head_size;
+    size_t num_heads;
+    std::vector<size_t> qkv_hidden_sizes; // order: {qk_hidden_size, qk_hidden_size, v_hidden_size}
+    float scale;
+
+    std::vector<size_t> qkv_head_sizes; // order: {qk_head_size, qk_head_size, v_head_size}
+
+    size_t batch_size;
+    size_t seq_len;
+    size_t input_hidden_size;
+    size_t hidden_size;
+
+    bool is_prepacked;
+    std::vector<float> packed_q;
+    std::vector<float> packed_k;
+    std::vector<float> packed_v;
+
+    FastGemmOpt opt;
 };
 
 Ptr<AttentionLayer> AttentionLayer::create(const LayerParams &params) {