+'''
+Text detection model: https://github.com/argman/EAST
+Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
+Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
+How to convert the recognition model from .pth to .onnx:
+Using the CRNN class from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
+import torch
+from models.crnn import CRNN
+model = CRNN(32, 1, 37, 256)
+model.load_state_dict(torch.load('crnn.pth'))
+dummy_input = torch.randn(1, 1, 32, 100)
+torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
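+
+Example invocation (illustrative sketch; assumes the script is saved as text_detection.py and
+both model files sit in the working directory, so the paths and file names are placeholders):
+python text_detection.py --input image.jpg --model frozen_east_text_detection.pb --ocr crnn.onnx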
+'''
+
+
 # Import required modules
+import numpy as np
 import cv2 as cv
 import math
 import argparse
 
 ############ Add argument parser for command line arguments ############
-parser = argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
-parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
-parser.add_argument('--model', required=True,
-                    help='Path to a binary .pb file of model contains trained weights.')
+parser = argparse.ArgumentParser(
+    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
+                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
+                "The OCR model can be obtained by converting the pretrained CRNN model to .onnx format from the GitHub repository https://github.com/meijieru/crnn.pytorch")
+parser.add_argument('--input',
+                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
+parser.add_argument('--model', '-m', required=True,
+                    help='Path to a binary .pb file containing the trained detector network.')
+parser.add_argument('--ocr', default="crnn.onnx",
+                    help="Path to a binary .pb or .onnx file containing the trained recognition network.")
 parser.add_argument('--width', type=int, default=320,
                     help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
-parser.add_argument('--height', type=int, default=320,
+parser.add_argument('--height', type=int, default=320,
                     help='Preprocess input image by resizing to a specific height. It should be a multiple of 32.')
-parser.add_argument('--thr', type=float, default=0.5,
+parser.add_argument('--thr', type=float, default=0.5,
                     help='Confidence threshold.')
-parser.add_argument('--nms', type=float, default=0.4,
+parser.add_argument('--nms', type=float, default=0.4,
                     help='Non-maximum suppression threshold.')
 args = parser.parse_args()
 
+
 ############ Utility functions ############
-def decode(scores, geometry, scoreThresh):
+
+def fourPointsTransform(frame, vertices):
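+    # Warp the detected quadrilateral onto an axis-aligned 100x32 patch,
+    # the input size the CRNN recognizer below expects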
+    vertices = np.asarray(vertices)
+    outputSize = (100, 32)
+    targetVertices = np.array([
+        [0, outputSize[1] - 1],
+        [0, 0],
+        [outputSize[0] - 1, 0],
+        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")
+
+    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
+    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
+    return result
+
+
+def decodeText(scores):
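+    # Greedy CTC-style decoding: take the most likely class at every time step;
+    # class 0 is the blank/background symbol, classes 1..36 index into the alphabet below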
+    text = ""
+    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
+    for i in range(scores.shape[0]):
+        c = np.argmax(scores[i][0])
+        if c != 0:
+            text += alphabet[c - 1]
+        else:
+            text += '-'
+
+    # Adjacent repeated letters as well as the background symbol ('-') must be removed to get the final output
+    char_list = []
+    for i in range(len(text)):
+        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
+            char_list.append(text[i])
+    return ''.join(char_list)
+
+
+def decodeBoundingBoxes(scores, geometry, scoreThresh):
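+    # "scores" is the text/no-text confidence map produced by EAST and "geometry" holds
+    # the rotated-box parameters (distances to the four box sides plus a rotation angle)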
     detections = []
     confidences = []
 
@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
             score = scoresData[x]
 
             # If score is lower than threshold score, move to next x
-            if (score < scoreThresh):
+            if (score < scoreThresh):
                 continue
 
             # Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
 
             # Find points for rectangle
             p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
-            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
-            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
-            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
+            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
+            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
+            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
             confidences.append(float(score))
 
     # Return detections and confidences
     return [detections, confidences]
 
+
 def main():
     # Read and store arguments
     confThreshold = args.thr
     nmsThreshold = args.nms
     inpWidth = args.width
     inpHeight = args.height
-    model = args.model
+    modelDetector = args.model
+    modelRecognition = args.ocr
 
     # Load networks
-    net = cv.dnn.readNet(model)
+    detector = cv.dnn.readNet(modelDetector)
+    recognizer = cv.dnn.readNet(modelRecognition)
 
     # Create a new named window
     kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
     # Open a video file or an image file or a camera stream
     cap = cv.VideoCapture(args.input if args.input else 0)
 
+    tickmeter = cv.TickMeter()
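+    # The TickMeter accumulates the time spent in the detection and recognition forward passes below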
     while cv.waitKey(1) < 0:
         # Read frame
         hasFrame, frame = cap.read()
@@ -111,36 +171,61 @@ def main():
         # Create a 4D blob from frame.
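+        # The tuple (123.68, 116.78, 103.94) is the per-channel mean subtracted from the frame
+        # (the standard ImageNet means); swapRB=True converts OpenCV's BGR order to RGB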
         blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
 
-        # Run the model
-        net.setInput(blob)
-        outs = net.forward(outNames)
-        t, _ = net.getPerfProfile()
-        label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
+        # Run the detection model
+        detector.setInput(blob)
+
+        tickmeter.start()
+        outs = detector.forward(outNames)
+        tickmeter.stop()
 
         # Get scores and geometry
         scores = outs[0]
         geometry = outs[1]
-        [boxes, confidences] = decode(scores, geometry, confThreshold)
+        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)
 
         # Apply NMS
-        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
+        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
         for i in indices:
             # get 4 corners of the rotated rect
             vertices = cv.boxPoints(boxes[i[0]])
             # scale the bounding box coordinates based on the respective ratios
             for j in range(4):
                 vertices[j][0] *= rW
                 vertices[j][1] *= rH
+
+
+            # get cropped image using perspective transform
+            if modelRecognition:
+                cropped = fourPointsTransform(frame, vertices)
+                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
+
+                # Create a 4D blob from cropped image
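+                # (mean=127.5 and scalefactor=1/127.5 scale the grayscale pixels to [-1, 1],
+                # the normalization the converted CRNN model expects)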
+                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
+                recognizer.setInput(blob)
+
+                # Run the recognition model
+                tickmeter.start()
+                result = recognizer.forward()
+                tickmeter.stop()
+
+                # decode the result into text
+                wordRecognized = decodeText(result)
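+                # Draw the recognized word at the vertex that maps to the top-left corner of the rectified crop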
+                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
+                           0.5, (255, 0, 0))
+
             for j in range(4):
                 p1 = (vertices[j][0], vertices[j][1])
                 p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
                 cv.line(frame, p1, p2, (0, 255, 0), 1)
 
         # Put efficiency information
+        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
         cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
 
         # Display the frame
-        cv.imshow(kWinName, frame)
+        cv.imshow(kWinName, frame)
+        tickmeter.reset()
+
 
 if __name__ == "__main__":
     main()