Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 5cd9d50

Browse files
fengyuentau authored and geversonsto committed
Merge pull request opencv#23219 from fengyuentau:add_gelu
Add GELU layer for vision transformers

* add gelu and gelu approximation
* drop setKernelParams
1 parent f577ff2, commit 5cd9d50

File tree

7 files changed

+292
-1
lines changed

7 files changed

+292
-1
lines changed

‎modules/dnn/include/opencv2/dnn/all_layers.hpp‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,18 @@ CV__DNN_INLINE_NS_BEGIN
806806
static Ptr<SeluLayer>create(const LayerParams &params);
807807
};
808808

809+
classCV_EXPORTS GeluLayer : public ActivationLayer
810+
{
811+
public:
812+
static Ptr<GeluLayer>create(const LayerParams &params);
813+
};
814+
815+
classCV_EXPORTS GeluApproximationLayer : public ActivationLayer
816+
{
817+
public:
818+
static Ptr<GeluApproximationLayer>create(const LayerParams &params);
819+
};
820+
809821
classCV_EXPORTS ThresholdedReluLayer : public ActivationLayer
810822
{
811823
public:

‎modules/dnn/src/init.cpp‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ void initializeLayerFactory()
145145
CV_DNN_REGISTER_LAYER_CLASS(HardSigmoid, HardSigmoidLayer);
146146
CV_DNN_REGISTER_LAYER_CLASS(Selu, SeluLayer);
147147
CV_DNN_REGISTER_LAYER_CLASS(ThresholdedRelu,ThresholdedReluLayer);
148+
CV_DNN_REGISTER_LAYER_CLASS(Gelu, GeluLayer);
149+
CV_DNN_REGISTER_LAYER_CLASS(GeluApproximation, GeluApproximationLayer);
148150
CV_DNN_REGISTER_LAYER_CLASS(BatchNorm, BatchNormLayer);
149151
CV_DNN_REGISTER_LAYER_CLASS(MaxUnpool, MaxUnpoolLayer);
150152
CV_DNN_REGISTER_LAYER_CLASS(Dropout, BlankLayer);

‎modules/dnn/src/layers/elementwise_layers.cpp‎

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,57 @@ struct BaseDefaultFunctor : public BaseFunctor
837837
staticconstchar*const ocl_kernel_name;
838838
};
839839

840+
structGeluFunctor :publicBaseDefaultFunctor<GeluFunctor>
841+
{
842+
typedef GeluLayer Layer;
843+
844+
explicitGeluFunctor() {}
845+
846+
boolsupportBackend(int backendId,int)
847+
{
848+
return backendId == DNN_BACKEND_OPENCV;
849+
}
850+
851+
inlinefloatcalculate(float x)const
852+
{
853+
return0.5f * x * (1.0f +erf(x * M_SQRT1_2));
854+
}
855+
856+
int64getFLOPSPerElement()const {return100; }
857+
};
858+
859+
template<>
860+
constchar*const BaseDefaultFunctor<GeluFunctor>::ocl_kernel_name ="GeluForward";
861+
862+
namespaceGeluApproximationConstants
863+
{
864+
staticconstexprfloat sqrt_2_pi =0.7978845834732056f;
865+
staticconstexprfloat coef_sqrt_2_pi =0.044714998453855515f * sqrt_2_pi;
866+
}
867+
868+
structGeluApproximationFunctor :publicBaseDefaultFunctor<GeluApproximationFunctor>
869+
{
870+
typedef GeluApproximationLayer Layer;
871+
872+
explicitGeluApproximationFunctor() {}
873+
874+
boolsupportBackend(int backendId,int)
875+
{
876+
return backendId == DNN_BACKEND_OPENCV;
877+
}
878+
879+
inlinefloatcalculate(float x)const
880+
{
881+
return0.5f * x * (1.f +tanh(x * (GeluApproximationConstants::sqrt_2_pi +
882+
GeluApproximationConstants::coef_sqrt_2_pi * x * x)));
883+
}
884+
885+
int64getFLOPSPerElement()const {return100; }
886+
};
887+
888+
template<>
889+
constchar*const BaseDefaultFunctor<GeluApproximationFunctor>::ocl_kernel_name ="GeluApproximationForward";
890+
840891
structTanHFunctor :publicBaseDefaultFunctor<TanHFunctor>
841892
{
842893
typedef TanHLayer Layer;
@@ -2694,6 +2745,22 @@ Ptr<ReLU6Layer> ReLU6Layer::create(const LayerParams& params)
26942745
return l;
26952746
}
26962747

2748+
Ptr<GeluLayer>GeluLayer::create(const LayerParams& params)
2749+
{
2750+
Ptr<GeluLayer>l(new ElementWiseLayer<GeluFunctor>(GeluFunctor()));
2751+
l->setParamsFrom(params);
2752+
2753+
return l;
2754+
}
2755+
2756+
Ptr<GeluApproximationLayer>GeluApproximationLayer::create(const LayerParams& params)
2757+
{
2758+
Ptr<GeluApproximationLayer>l(new ElementWiseLayer<GeluApproximationFunctor>(GeluApproximationFunctor()));
2759+
l->setParamsFrom(params);
2760+
2761+
return l;
2762+
}
2763+
26972764
Ptr<TanHLayer>TanHLayer::create(const LayerParams& params)
26982765
{
26992766
Ptr<TanHLayer>l(new ElementWiseLayer<TanHFunctor>());

‎modules/dnn/src/onnx/onnx_graph_simplifier.cpp‎

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,183 @@ class ONNXGraphWrapper : public ImportGraphWrapper
132132
opencv_onnx::GraphProto& net;
133133
};
134134

135+
/* Fusion for Gelu.
136+
137+
Graph before fusion:
138+
+---------------------------------------------+
139+
| |
140+
[Input] -> Div[B=sqrt(2)] -> Erf -> Add[B=1] -> Mul -> Mul[B=0.5] -> [Output]
141+
142+
Graph after fusion:
143+
[Input] -> Gelu -> [Output]
144+
145+
*/
146+
classGeluSubGraph :publicSubgraph
147+
{
148+
public:
149+
GeluSubGraph()
150+
{
151+
int input =addNodeToMatch("");
152+
int div =addNodeToMatch("Div", input,addNodeToMatch("")/* B=sqrt(2)*/ );
153+
int erf =addNodeToMatch("Erf", div);
154+
int add =addNodeToMatch("Add", erf,addNodeToMatch("")/* B=1*/ );
155+
int mul =addNodeToMatch("Mul", input, add);
156+
addNodeToMatch("Mul", mul,addNodeToMatch("")/* B=0.5*/) ;
157+
158+
setFusedNode("Gelu", input);
159+
}
160+
161+
staticboolisWithInitializer(const std::vector<int>& matchedNodesIds)
162+
{
163+
// if node.getType() is Constant, Constant nodes are placed between other nodes
164+
if (matchedNodesIds[2] - matchedNodesIds[1] !=1)
165+
returnfalse;
166+
// if Initializer, there is no Constant node between other nodes
167+
returntrue;
168+
}
169+
170+
staticfloatextractConstant(const Ptr<ImportGraphWrapper>& net,int node_id,int input_id,bool withInitializer)
171+
{
172+
if (withInitializer)
173+
{
174+
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
175+
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
176+
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
177+
return *const_mat.ptr<float>();
178+
}else {
179+
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
180+
int constant_id =getInputNodeId(net, node, input_id);
181+
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
182+
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
183+
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
184+
Mat constant_mat =getMatFromTensor(constant_proto);
185+
return *constant_mat.ptr<float>();
186+
}
187+
}
188+
189+
virtualboolmatch(const Ptr<ImportGraphWrapper>& net,int nodeId,
190+
std::vector<int>& matchedNodesIds,
191+
std::vector<int>& targetNodesIds) CV_OVERRIDE
192+
{
193+
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
194+
{
195+
bool withInitializer =isWithInitializer(matchedNodesIds);
196+
197+
// Check Div[B=sqrt(2)]
198+
float divisor =extractConstant(net, matchedNodesIds[0],1, withInitializer);
199+
if (divisor - M_SQRT2 >=1e-6)
200+
returnfalse;
201+
202+
// Check Add[B=1]
203+
float add_const =extractConstant(net, matchedNodesIds[2],1, withInitializer);
204+
if (add_const -1.f >=1e-6)
205+
returnfalse;
206+
207+
// Check Mul[B=0.5]
208+
float mul_const =extractConstant(net, matchedNodesIds[4],1, withInitializer);
209+
if (mul_const -0.5f >=1e-6)
210+
returnfalse;
211+
212+
returntrue;
213+
}
214+
returnfalse;
215+
}
216+
};
217+
218+
/* Fusion for GeluApproximation.
219+
220+
Graph before fusion:
221+
+--------+------+----------------+------------------------------------+
222+
| | | | |
223+
[Input] -> Mul -> Mul -> Mul[ ] -> Add -> Mul[ ] -> Tanh -> Add[A=1] -> Mul -> Mul(A=0.5) -> [Output]
224+
/ \
225+
A=0.044714998453855515 A=sqrt(2/pie)
226+
227+
Graph after fusion:
228+
[Input] -> GeluApproximation -> [Output]
229+
230+
*/
231+
classGeluApproximationSubGraph :publicSubgraph
232+
{
233+
public:
234+
GeluApproximationSubGraph()
235+
{
236+
int input =addNodeToMatch("");
237+
int mul0 =addNodeToMatch("Mul", input, input);
238+
int mul1 =addNodeToMatch("Mul", input, mul0);
239+
int mul2 =addNodeToMatch("Mul",addNodeToMatch("")/* A=0.044714998453855515*/, mul1);
240+
int add0 =addNodeToMatch("Add", input, mul2);
241+
int mul3 =addNodeToMatch("Mul",addNodeToMatch("")/* A=sqrt(2/pie)*/, add0);
242+
int tanh =addNodeToMatch("Tanh", mul3);
243+
int add1 =addNodeToMatch("Add",addNodeToMatch("")/* A=1*/, tanh);
244+
int mul4 =addNodeToMatch("Mul", input, add1);
245+
addNodeToMatch("Mul",addNodeToMatch("")/* A=0.5*/, mul4);
246+
247+
setFusedNode("GeluApproximation", input);
248+
}
249+
250+
staticboolisWithInitializer(const std::vector<int>& matchedNodesIds)
251+
{
252+
// if node.getType() is Constant, Constant nodes are placed between other nodes
253+
if (matchedNodesIds[2] - matchedNodesIds[1] !=1)
254+
returnfalse;
255+
// if Initializer, there is no Constant node between other nodes
256+
returntrue;
257+
}
258+
259+
staticfloatextractConstant(const Ptr<ImportGraphWrapper>& net,int node_id,int input_id,bool withInitializer)
260+
{
261+
if (withInitializer)
262+
{
263+
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
264+
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
265+
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
266+
return *const_mat.ptr<float>();
267+
}else {
268+
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
269+
int constant_id =getInputNodeId(net, node, input_id);
270+
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
271+
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
272+
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
273+
Mat constant_mat =getMatFromTensor(constant_proto);
274+
return *constant_mat.ptr<float>();
275+
}
276+
}
277+
278+
virtualboolmatch(const Ptr<ImportGraphWrapper>& net,int nodeId,
279+
std::vector<int>& matchedNodesIds,
280+
std::vector<int>& targetNodesIds) CV_OVERRIDE
281+
{
282+
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
283+
{
284+
bool withInitializer =isWithInitializer(matchedNodesIds);
285+
286+
// Check Mul[A=0.044714998453855515]
287+
float coef =extractConstant(net, matchedNodesIds[2],0, withInitializer);
288+
if (coef -0.044714998453855515 >=1e-6)
289+
returnfalse;
290+
291+
// Check Mul[A=sqrt(2/pie)]
292+
float sqrt_2_pie =extractConstant(net, matchedNodesIds[4],0, withInitializer);
293+
if (sqrt_2_pie -0.7978845834732056 >=1e-6)
294+
returnfalse;
295+
296+
// Check Add[A=1]
297+
float add_const =extractConstant(net, matchedNodesIds[6],0, withInitializer);
298+
if (add_const -1.f >=1e-6)
299+
returnfalse;
300+
301+
// Check Mul[A=0.5]
302+
float mul_const =extractConstant(net, matchedNodesIds[8],0, withInitializer);
303+
if (mul_const -0.5f >=1e-6)
304+
returnfalse;
305+
306+
returntrue;
307+
}
308+
returnfalse;
309+
}
310+
};
311+
135312
classLayerNormSubGraph :publicSubgraph
136313
{
137314
public:
@@ -904,6 +1081,8 @@ class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
9041081
voidsimplifySubgraphs(opencv_onnx::GraphProto& net)
9051082
{
9061083
std::vector<Ptr<Subgraph> > subgraphs;
1084+
subgraphs.push_back(makePtr<GeluSubGraph>());
1085+
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
9071086
subgraphs.push_back(makePtr<LayerNormSubGraph>());
9081087
subgraphs.push_back(makePtr<GatherCastSubgraph>());
9091088
subgraphs.push_back(makePtr<MulCastSubgraph>());

‎modules/dnn/src/onnx/onnx_importer.cpp‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4051,7 +4051,8 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
40514051
std::vector<std::string> simpleLayers{"Acos","Acosh","Asin","Asinh","Atan","Atanh","Ceil","Celu","Cos",
40524052
"Cosh","Dropout","Erf","Exp","Floor","HardSigmoid","HardSwish",
40534053
"Identity","Log","Round","Reciprocal","Selu","Sign","Sigmoid","Sin","Sinh","Softmax",
4054-
"Softplus","Softsign","Shrink","Sqrt","Tan","ThresholdedRelu"};
4054+
"Softplus","Softsign","Shrink","Sqrt","Tan","ThresholdedRelu","Gelu",
4055+
"GeluApproximation"};
40554056
for (constauto& name : simpleLayers)
40564057
{
40574058
dispatch[name] = &ONNXImporter::parseSimpleLayers;

‎modules/dnn/src/opencl/activations.cl‎

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,30 @@ __kernel void ThresholdedReluForward(const int n, __global T* in, __global T* ou
307307
out[index]= (in[index]>alpha ?in[index] :0.f);
308308
}
309309

310+
__kernelvoidGeluForward(constintn,__globalT*in,__globalT*out)
311+
{
312+
intindex=get_global_id(0);
313+
if (index<n)
314+
{
315+
Tx=in[index];
316+
out[index]= (T)0.5f*x* ( (T)1.f+erf(x*M_SQRT1_2) );
317+
}
318+
}
319+
320+
__kernelvoidGeluApproximationForward(constintn,__globalT*in,__globalT*out)
321+
{
322+
// see GeluApproximationConstants from modules/dnn/src/layers/elementwise_layers.cpp
323+
constTsqrt_2_pi=0.7978845834732056f;
324+
constTcoef_sqrt_2_pi=0.044714998453855515f*sqrt_2_pi;
325+
326+
intindex=get_global_id(0);
327+
if(index<n)
328+
{
329+
Tx=in[index];
330+
out[index]= (T)0.5f*x* ( (T)1.f+tanh(x* (sqrt_2_pi+coef_sqrt_2_pi*x*x)) );
331+
}
332+
}
333+
310334
__kernelvoidShrinkForward(constintn,__globalT*in,__globalT*out,
311335
constKERNEL_ARG_DTYPEbias,
312336
constKERNEL_ARG_DTYPElambd)

‎modules/dnn/test/test_onnx_importer.cpp‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,6 +2456,12 @@ TEST_P(Test_ONNX_layers, LayerNormExpanded)
24562456
testONNXModels("layer_norm_expanded_with_initializers");
24572457
}
24582458

2459+
TEST_P(Test_ONNX_layers, Gelu)
2460+
{
2461+
testONNXModels("gelu");
2462+
testONNXModels("gelu_approximation");
2463+
}
2464+
24592465
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
24602466

24612467
}}// namespace

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp