@@ -132,6 +132,183 @@ class ONNXGraphWrapper : public ImportGraphWrapper
132132 opencv_onnx::GraphProto& net;
133133};
134134
135+ /* Fusion for Gelu.
136+
137+ Graph before fusion:
138+ +---------------------------------------------+
139+ | |
140+ [Input] -> Div[B=sqrt(2)] -> Erf -> Add[B=1] -> Mul -> Mul[B=0.5] -> [Output]
141+
142+ Graph after fusion:
143+ [Input] -> Gelu -> [Output]
144+
145+ */
146+ class GeluSubGraph :public Subgraph
147+ {
148+ public:
149+ GeluSubGraph ()
150+ {
151+ int input =addNodeToMatch (" " );
152+ int div =addNodeToMatch (" Div" , input,addNodeToMatch (" " )/* B=sqrt(2)*/ );
153+ int erf =addNodeToMatch (" Erf" , div);
154+ int add =addNodeToMatch (" Add" , erf,addNodeToMatch (" " )/* B=1*/ );
155+ int mul =addNodeToMatch (" Mul" , input, add);
156+ addNodeToMatch (" Mul" , mul,addNodeToMatch (" " )/* B=0.5*/ ) ;
157+
158+ setFusedNode (" Gelu" , input);
159+ }
160+
161+ static bool isWithInitializer (const std::vector<int >& matchedNodesIds)
162+ {
163+ // if node.getType() is Constant, Constant nodes are placed between other nodes
164+ if (matchedNodesIds[2 ] - matchedNodesIds[1 ] !=1 )
165+ return false ;
166+ // if Initializer, there is no Constant node between other nodes
167+ return true ;
168+ }
169+
170+ static float extractConstant (const Ptr<ImportGraphWrapper>& net,int node_id,int input_id,bool withInitializer)
171+ {
172+ if (withInitializer)
173+ {
174+ auto onnx_net = net.dynamicCast <ONNXGraphWrapper>();
175+ int initializer_id = onnx_net->getInputInitializerId (node_id, input_id);
176+ Mat const_mat = onnx_net->getMatFromInitializer (initializer_id);
177+ return *const_mat.ptr <float >();
178+ }else {
179+ const Ptr<ImportNodeWrapper> node = net->getNode (node_id);
180+ int constant_id =getInputNodeId (net, node, input_id);
181+ Ptr<ImportNodeWrapper> constant_ptr = net->getNode (constant_id);
182+ opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast <ONNXNodeWrapper>()->node ;
183+ opencv_onnx::TensorProto constant_proto = constant_node->attribute (0 ).t ();
184+ Mat constant_mat =getMatFromTensor (constant_proto);
185+ return *constant_mat.ptr <float >();
186+ }
187+ }
188+
189+ virtual bool match (const Ptr<ImportGraphWrapper>& net,int nodeId,
190+ std::vector<int >& matchedNodesIds,
191+ std::vector<int >& targetNodesIds) CV_OVERRIDE
192+ {
193+ if (Subgraph::match (net, nodeId, matchedNodesIds, targetNodesIds))
194+ {
195+ bool withInitializer =isWithInitializer (matchedNodesIds);
196+
197+ // Check Div[B=sqrt(2)]
198+ float divisor =extractConstant (net, matchedNodesIds[0 ],1 , withInitializer);
199+ if (divisor - M_SQRT2 >=1e-6 )
200+ return false ;
201+
202+ // Check Add[B=1]
203+ float add_const =extractConstant (net, matchedNodesIds[2 ],1 , withInitializer);
204+ if (add_const -1 .f >=1e-6 )
205+ return false ;
206+
207+ // Check Mul[B=0.5]
208+ float mul_const =extractConstant (net, matchedNodesIds[4 ],1 , withInitializer);
209+ if (mul_const -0 .5f >=1e-6 )
210+ return false ;
211+
212+ return true ;
213+ }
214+ return false ;
215+ }
216+ };
217+
218+ /* Fusion for GeluApproximation.
219+
220+ Graph before fusion:
221+ +--------+------+----------------+------------------------------------+
222+ | | | | |
223+ [Input] -> Mul -> Mul -> Mul[ ] -> Add -> Mul[ ] -> Tanh -> Add[A=1] -> Mul -> Mul(A=0.5) -> [Output]
224+ / \
225+ A=0.044714998453855515 A=sqrt(2/pie)
226+
227+ Graph after fusion:
228+ [Input] -> GeluApproximation -> [Output]
229+
230+ */
231+ class GeluApproximationSubGraph :public Subgraph
232+ {
233+ public:
234+ GeluApproximationSubGraph ()
235+ {
236+ int input =addNodeToMatch (" " );
237+ int mul0 =addNodeToMatch (" Mul" , input, input);
238+ int mul1 =addNodeToMatch (" Mul" , input, mul0);
239+ int mul2 =addNodeToMatch (" Mul" ,addNodeToMatch (" " )/* A=0.044714998453855515*/ , mul1);
240+ int add0 =addNodeToMatch (" Add" , input, mul2);
241+ int mul3 =addNodeToMatch (" Mul" ,addNodeToMatch (" " )/* A=sqrt(2/pie)*/ , add0);
242+ int tanh =addNodeToMatch (" Tanh" , mul3);
243+ int add1 =addNodeToMatch (" Add" ,addNodeToMatch (" " )/* A=1*/ , tanh);
244+ int mul4 =addNodeToMatch (" Mul" , input, add1);
245+ addNodeToMatch (" Mul" ,addNodeToMatch (" " )/* A=0.5*/ , mul4);
246+
247+ setFusedNode (" GeluApproximation" , input);
248+ }
249+
250+ static bool isWithInitializer (const std::vector<int >& matchedNodesIds)
251+ {
252+ // if node.getType() is Constant, Constant nodes are placed between other nodes
253+ if (matchedNodesIds[2 ] - matchedNodesIds[1 ] !=1 )
254+ return false ;
255+ // if Initializer, there is no Constant node between other nodes
256+ return true ;
257+ }
258+
259+ static float extractConstant (const Ptr<ImportGraphWrapper>& net,int node_id,int input_id,bool withInitializer)
260+ {
261+ if (withInitializer)
262+ {
263+ auto onnx_net = net.dynamicCast <ONNXGraphWrapper>();
264+ int initializer_id = onnx_net->getInputInitializerId (node_id, input_id);
265+ Mat const_mat = onnx_net->getMatFromInitializer (initializer_id);
266+ return *const_mat.ptr <float >();
267+ }else {
268+ const Ptr<ImportNodeWrapper> node = net->getNode (node_id);
269+ int constant_id =getInputNodeId (net, node, input_id);
270+ Ptr<ImportNodeWrapper> constant_ptr = net->getNode (constant_id);
271+ opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast <ONNXNodeWrapper>()->node ;
272+ opencv_onnx::TensorProto constant_proto = constant_node->attribute (0 ).t ();
273+ Mat constant_mat =getMatFromTensor (constant_proto);
274+ return *constant_mat.ptr <float >();
275+ }
276+ }
277+
278+ virtual bool match (const Ptr<ImportGraphWrapper>& net,int nodeId,
279+ std::vector<int >& matchedNodesIds,
280+ std::vector<int >& targetNodesIds) CV_OVERRIDE
281+ {
282+ if (Subgraph::match (net, nodeId, matchedNodesIds, targetNodesIds))
283+ {
284+ bool withInitializer =isWithInitializer (matchedNodesIds);
285+
286+ // Check Mul[A=0.044714998453855515]
287+ float coef =extractConstant (net, matchedNodesIds[2 ],0 , withInitializer);
288+ if (coef -0.044714998453855515 >=1e-6 )
289+ return false ;
290+
291+ // Check Mul[A=sqrt(2/pie)]
292+ float sqrt_2_pie =extractConstant (net, matchedNodesIds[4 ],0 , withInitializer);
293+ if (sqrt_2_pie -0.7978845834732056 >=1e-6 )
294+ return false ;
295+
296+ // Check Add[A=1]
297+ float add_const =extractConstant (net, matchedNodesIds[6 ],0 , withInitializer);
298+ if (add_const -1 .f >=1e-6 )
299+ return false ;
300+
301+ // Check Mul[A=0.5]
302+ float mul_const =extractConstant (net, matchedNodesIds[8 ],0 , withInitializer);
303+ if (mul_const -0 .5f >=1e-6 )
304+ return false ;
305+
306+ return true ;
307+ }
308+ return false ;
309+ }
310+ };
311+
135312class LayerNormSubGraph :public Subgraph
136313{
137314public:
@@ -904,6 +1081,8 @@ class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
9041081void simplifySubgraphs (opencv_onnx::GraphProto& net)
9051082{
9061083 std::vector<Ptr<Subgraph> > subgraphs;
1084+ subgraphs.push_back (makePtr<GeluSubGraph>());
1085+ subgraphs.push_back (makePtr<GeluApproximationSubGraph>());
9071086 subgraphs.push_back (makePtr<LayerNormSubGraph>());
9081087 subgraphs.push_back (makePtr<GatherCastSubgraph>());
9091088 subgraphs.push_back (makePtr<MulCastSubgraph>());