Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1b336bb

Browse files
Merge pull request#16955 from themechanicalcoder:text_recognition
* add text recognition sample* fix pylint warning* made changes according to the c++ example* fix errors* add text recognition sample* update text detection sample
1 parent0fb3b8d commit1b336bb

File tree

1 file changed

+107
-22
lines changed

1 file changed

+107
-22
lines changed

‎samples/dnn/text_detection.py‎

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,81 @@
1+
'''
2+
Text detection model: https://github.com/argman/EAST
3+
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
4+
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
5+
    How to convert from pth (PyTorch checkpoint) to onnx:
6+
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
7+
import torch
8+
import models.crnn as CRNN
9+
    model = CRNN.CRNN(32, 1, 37, 256)
10+
model.load_state_dict(torch.load('crnn.pth'))
11+
dummy_input = torch.randn(1, 1, 32, 100)
12+
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
13+
'''
14+
15+
116
# Import required modules
17+
importnumpyasnp
218
importcv2ascv
319
importmath
420
importargparse
521

622
############ Add argument parser for command line arguments ############
7-
parser=argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
8-
parser.add_argument('--input',help='Path to input image or video file. Skip this argument to capture frames from a camera.')
9-
parser.add_argument('--model',required=True,
10-
help='Path to a binary .pb file of model contains trained weights.')
23+
parser=argparse.ArgumentParser(
24+
description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
25+
"EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)"
26+
"The OCR model can be obtained from converting the pretrained CRNN model to .onnx format from the github repository https://github.com/meijieru/crnn.pytorch")
27+
parser.add_argument('--input',
28+
help='Path to input image or video file. Skip this argument to capture frames from a camera.')
29+
parser.add_argument('--model','-m',required=True,
30+
help='Path to a binary .pb file contains trained detector network.')
31+
parser.add_argument('--ocr',default="crnn.onnx",
32+
help="Path to a binary .pb or .onnx file contains trained recognition network", )
1133
parser.add_argument('--width',type=int,default=320,
1234
help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
13-
parser.add_argument('--height',type=int,default=320,
35+
parser.add_argument('--height',type=int,default=320,
1436
help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
15-
parser.add_argument('--thr',type=float,default=0.5,
37+
parser.add_argument('--thr',type=float,default=0.5,
1638
help='Confidence threshold.')
17-
parser.add_argument('--nms',type=float,default=0.4,
39+
parser.add_argument('--nms',type=float,default=0.4,
1840
help='Non-maximum suppression threshold.')
1941
args=parser.parse_args()
2042

43+
2144
############ Utility functions ############
22-
defdecode(scores,geometry,scoreThresh):
45+
46+
def fourPointsTransform(frame, vertices):
    """Warp the quadrilateral described by *vertices* into a 100x32 image.

    Args:
        frame: Source image the quadrilateral is cut from.
        vertices: Four (x, y) corner points in the order bottom-left,
            top-left, top-right, bottom-right, so that they line up with the
            target corners built below.

    Returns:
        The perspective-corrected crop, 100 pixels wide and 32 pixels high.
    """
    # cv.getPerspectiveTransform only accepts float32 points; coerce here so
    # that plain Python lists or float64 arrays work as well.
    vertices = np.asarray(vertices, dtype="float32")
    outputSize = (100, 32)  # (width, height) of the rectified crop
    width, height = outputSize
    targetVertices = np.array([
        [0, height - 1],
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    return cv.warpPerspective(frame, rotationMatrix, outputSize)
58+
59+
60+
def decodeText(scores, alphabet="0123456789abcdefghijklmnopqrstuvwxyz"):
    """Greedily decode per-step class scores into a string.

    For every step along the first axis the highest-scoring class of
    ``scores[i][0]`` is taken. Class 0 is treated as the background (blank)
    symbol; class ``k > 0`` maps to ``alphabet[k - 1]``. Background steps are
    dropped and runs of the same class on adjacent steps collapse into a
    single character.

    Args:
        scores: Array indexed as ``scores[step][0]`` -> class scores.
        alphabet: Characters for classes 1..len(alphabet).

    Returns:
        The decoded text (empty if there are no steps or only background).
    """
    chars = []
    previous = 0  # start as "background" so the first real class is kept
    for step in scores:
        current = int(np.argmax(step[0]))
        # Keep a character only when it is not background and does not
        # repeat the class chosen on the immediately preceding step.
        if current != 0 and current != previous:
            chars.append(alphabet[current - 1])
        previous = current
    return ''.join(chars)
76+
77+
78+
defdecodeBoundingBoxes(scores,geometry,scoreThresh):
2379
detections= []
2480
confidences= []
2581

@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
47103
score=scoresData[x]
48104

49105
# If score is lower than threshold score, move to next x
50-
if(score<scoreThresh):
106+
if(score<scoreThresh):
51107
continue
52108

53109
# Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
66122

67123
# Find points for rectangle
68124
p1= (-sinA*h+offset[0],-cosA*h+offset[1])
69-
p3= (-cosA*w+offset[0],sinA*w+offset[1])
70-
center= (0.5*(p1[0]+p3[0]),0.5*(p1[1]+p3[1]))
71-
detections.append((center, (w,h),-1*angle*180.0/math.pi))
125+
p3= (-cosA*w+offset[0],sinA*w+offset[1])
126+
center= (0.5*(p1[0]+p3[0]),0.5*(p1[1]+p3[1]))
127+
detections.append((center, (w,h),-1*angle*180.0/math.pi))
72128
confidences.append(float(score))
73129

74130
# Return detections and confidences
75131
return [detections,confidences]
76132

133+
77134
defmain():
78135
# Read and store arguments
79136
confThreshold=args.thr
80137
nmsThreshold=args.nms
81138
inpWidth=args.width
82139
inpHeight=args.height
83-
model=args.model
140+
modelDetector=args.model
141+
modelRecognition=args.ocr
84142

85143
# Load network
86-
net=cv.dnn.readNet(model)
144+
detector=cv.dnn.readNet(modelDetector)
145+
recognizer=cv.dnn.readNet(modelRecognition)
87146

88147
# Create a new named window
89148
kWinName="EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
95154
# Open a video file or an image file or a camera stream
96155
cap=cv.VideoCapture(args.inputifargs.inputelse0)
97156

157+
tickmeter=cv.TickMeter()
98158
whilecv.waitKey(1)<0:
99159
# Read frame
100160
hasFrame,frame=cap.read()
@@ -111,36 +171,61 @@ def main():
111171
# Create a 4D blob from frame.
112172
blob=cv.dnn.blobFromImage(frame,1.0, (inpWidth,inpHeight), (123.68,116.78,103.94),True,False)
113173

114-
# Run the model
115-
net.setInput(blob)
116-
outs=net.forward(outNames)
117-
t,_=net.getPerfProfile()
118-
label='Inference time: %.2f ms'% (t*1000.0/cv.getTickFrequency())
174+
# Run the detection model
175+
detector.setInput(blob)
176+
177+
tickmeter.start()
178+
outs=detector.forward(outNames)
179+
tickmeter.stop()
119180

120181
# Get scores and geometry
121182
scores=outs[0]
122183
geometry=outs[1]
123-
[boxes,confidences]=decode(scores,geometry,confThreshold)
184+
[boxes,confidences]=decodeBoundingBoxes(scores,geometry,confThreshold)
124185

125186
# Apply NMS
126-
indices=cv.dnn.NMSBoxesRotated(boxes,confidences,confThreshold,nmsThreshold)
187+
indices=cv.dnn.NMSBoxesRotated(boxes,confidences,confThreshold,nmsThreshold)
127188
foriinindices:
128189
# get 4 corners of the rotated rect
129190
vertices=cv.boxPoints(boxes[i[0]])
130191
# scale the bounding box coordinates based on the respective ratios
131192
forjinrange(4):
132193
vertices[j][0]*=rW
133194
vertices[j][1]*=rH
195+
196+
197+
# get cropped image using perspective transform
198+
ifmodelRecognition:
199+
cropped=fourPointsTransform(frame,vertices)
200+
cropped=cv.cvtColor(cropped,cv.COLOR_BGR2GRAY)
201+
202+
# Create a 4D blob from cropped image
203+
blob=cv.dnn.blobFromImage(cropped,size=(100,32),mean=127.5,scalefactor=1/127.5)
204+
recognizer.setInput(blob)
205+
206+
# Run the recognition model
207+
tickmeter.start()
208+
result=recognizer.forward()
209+
tickmeter.stop()
210+
211+
# decode the result into text
212+
wordRecognized=decodeText(result)
213+
cv.putText(frame,wordRecognized, (int(vertices[1][0]),int(vertices[1][1])),cv.FONT_HERSHEY_SIMPLEX,
214+
0.5, (255,0,0))
215+
134216
forjinrange(4):
135217
p1= (vertices[j][0],vertices[j][1])
136218
p2= (vertices[(j+1)%4][0],vertices[(j+1)%4][1])
137219
cv.line(frame,p1,p2, (0,255,0),1)
138220

139221
# Put efficiency information
222+
label='Inference time: %.2f ms'% (tickmeter.getTimeMilli())
140223
cv.putText(frame,label, (0,15),cv.FONT_HERSHEY_SIMPLEX,0.5, (0,255,0))
141224

142225
# Display the frame
143-
cv.imshow(kWinName,frame)
226+
cv.imshow(kWinName,frame)
227+
tickmeter.reset()
228+
144229

145230
if__name__=="__main__":
146231
main()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp