# train_and_predict_5.py

# -*- coding: utf-8 -*-
# You may need to restart your runtime prior to this, to let your installation take effect
# Some basic setup
import cv2
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
import os
import time
from datetime import datetime
import torch
import numpy as np
import logging
import logging.handlers
import camera_configs
from mysql import mysql
from multiprocessing import Process, Queue
from img_preprocess import BGR_stretching
from video_capture import img_capture, LOG_LEVEL, BINOCULAR_ID
from websocket_server import push_to_web
import json
import ast
from configparser import ConfigParser

config = ConfigParser()
config.read(r'./config.ini', encoding='utf-8')
# fish species mapping table
FISH_SPECIES_MAP = ast.literal_eval(config['species']['FISH_SPECIES_MAP'])
# storage path for images after detection
IMAGE_PATH = config['path']['IMAGE_PATH']
if not os.path.exists(IMAGE_PATH):
    os.mkdir(IMAGE_PATH)
# MySQL connection settings
MYSQL_HOST = config['mysql']['MYSQL_HOST']
MYSQL_PORT = int(config['mysql']['MYSQL_PORT'])
MYSQL_USER = config['mysql']['MYSQL_USER']
MYSQL_PASSWORD = config['mysql']['MYSQL_PASSWORD']
MYSQL_DATABASE = config['mysql']['MYSQL_DATABASE']
MYSQL_HOST_LOCAL = '127.0.0.1'


def logger_init(LOG_LEVEL):
    logger = logging.getLogger('mylogger')
    logger.setLevel(LOG_LEVEL)
    formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
    # rotate at 1 MB and keep 5 backups; master.log.5 ... master.log.1: oldest to newest
    file_handler = logging.handlers.RotatingFileHandler('master.log', mode='a', maxBytes=1024*1024, backupCount=5)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger


def fasterrcnn_init():
    global predictor_fasterrcnn
    cfg_fasterrcnn = get_cfg()
    cfg_fasterrcnn.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
    cfg_fasterrcnn.OUTPUT_DIR = './weights'
    cfg_fasterrcnn.MODEL.WEIGHTS = os.path.join(cfg_fasterrcnn.OUTPUT_DIR, 'model_final_fasterrcnn.pth')
    # set SCORE_THRESH_TEST to change the detection threshold
    cfg_fasterrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8  # testing threshold for this model
    cfg_fasterrcnn.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg_fasterrcnn.MODEL.ROI_HEADS.NUM_CLASSES = 2
    #cfg_fasterrcnn.MODEL.DEVICE = 'cpu'
    predictor_fasterrcnn = DefaultPredictor(cfg_fasterrcnn)
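
# NOTE: this detector is trained with two classes (0 = fish head, 1 = fish tail, see
# whole_filter_new() below) and is only used to check whether a segmented fish is whole.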


def whole_filter_new(image, insseg_output):
    # this function filters the instance-segmentation boxes, keeping only those that contain a whole fish
    # parameters:
    #   image: original image
    #   insseg_output: Mask R-CNN output for that image
    # confident_list records whether each box is kept: 1 keep, 0 remove
    confident_list = []
    boxes = insseg_output['instances'].pred_boxes.tensor.clone().detach().cpu().numpy()
    # centers of detected heads and tails
    heads_center = []
    tails_center = []
    #global predictor_fasterrcnn
    output = predictor_fasterrcnn(image)
    # no object in output
    if len(output['instances']) == 0:
        return None
    # don't define local variables, to avoid memory growth
    #fasterrcnn_boxes = output['instances'].pred_boxes.tensor
    #fasterrcnn_classes = output['instances'].pred_classes
    # collect head and tail centers into heads_center and tails_center
    for i, box in enumerate(output['instances'].pred_boxes.tensor):
        # 0 is head
        if output['instances'].pred_classes[i] == 0:
            heads_center.append((((box[0]+box[2])/2).item(), ((box[1]+box[3])/2).item()))
        # 1 is tail
        else:
            tails_center.append((((box[0]+box[2])/2).item(), ((box[1]+box[3])/2).item()))
    #print(heads_center,tails_center)
    for box in boxes:
        if contains_whole_fish(box, heads_center, tails_center):
            confident_list.append(1)
        else:
            confident_list.append(0)
    # no whole fish, return None
    if 1 not in confident_list:
        return None
    return confident_list
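
# Example (hypothetical values): for a Mask R-CNN box (x1, y1, x2, y2) = (100, 50, 400, 300),
# heads_center = [(150.0, 120.0)] and tails_center = [(380.0, 260.0)] both fall inside the box,
# so contains_whole_fish() returns True and that box gets a 1 in confident_list.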


def contains_whole_fish(box, heads_center, tails_center):
    # check whether a box contains both a fish head and a fish tail
    # box: one box from the Mask R-CNN output['instances'].pred_boxes.tensor
    # heads_center, tails_center: [(x1,y1),(x2,y2),...]
    if contains_point(box, heads_center) and contains_point(box, tails_center):
        return True
    return False


def contains_point(box, point_list):
    # check whether the box contains at least one of the given points
    # box: one box from the Mask R-CNN output['instances'].pred_boxes.tensor
    # point_list: [(x1,y1),(x2,y2),...]
    for point in point_list:
        if box[0] <= point[0] <= box[2] and box[1] <= point[1] <= box[3]:
            return True
    return False


def get_classes_list(confident_list, output):
    # build the class list for the whole fish.
    # The correspondence between contours_total (from get_contours_new()) and confident_list changes,
    # so rebuild the class list in order to trace which class each contour belongs to.
    #   confident_list: output of whole_filter_new()
    #   output: Mask R-CNN output
    class_list = []
    for i, confident in enumerate(confident_list):
        if confident == 1:
            class_list.append(output['instances'].pred_classes[i].clone().detach().cpu().item())
    return class_list


def get_contours_new(confident_list, output):
    # extract the contours of the whole fish from the Mask R-CNN output
    #   confident_list: flags marking which instances are whole fish
    #   output: Mask R-CNN output
    contours_total = []
    # normally output['instances'] will not have length 0
    if len(output['instances']) < 1 or len(confident_list) < 1:
        return None
    for i, confident in enumerate(confident_list):
        if confident == 1:
            binary_img = output['instances'].pred_masks[i].clone().detach().cpu().numpy().astype(np.uint8)*255
            contours, hierarchy = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)  # cv2.CHAIN_APPROX_SIMPLE
            contours_total.extend(contours)
    return contours_total


def distance_measure(left_point1, left_point2, right_point1, right_point2):
    # compute the fish length from the corresponding endpoints in the left and right cameras
    location1 = np.array([[[left_point1[0], left_point1[1], left_point1[0] - right_point1[0]]]], dtype=np.float32)
    XYZ1 = cv2.perspectiveTransform(location1, camera_configs.Q)
    location2 = np.array([[[left_point2[0], left_point2[1], left_point2[0] - right_point2[0]]]], dtype=np.float32)
    XYZ2 = cv2.perspectiveTransform(location2, camera_configs.Q)
    distance = ((XYZ2[0][0][0] - XYZ1[0][0][0])**2 + (XYZ2[0][0][1] - XYZ1[0][0][1])**2 + (XYZ2[0][0][2] - XYZ1[0][0][2])**2)**0.5
    return distance
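
# How the measurement works (sketch): for a rectified stereo pair, the disparity of a point is
# x_left - x_right. cv2.perspectiveTransform() applied to (x, y, disparity) with the 4x4 Q matrix
# from camera_configs reprojects the pixel into 3D camera coordinates (X, Y, Z), and the fish
# length is the Euclidean distance between the two reprojected endpoints.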


def get_point_pair(rect_l, c_l, rect_r, c_r):
    # find corresponding endpoints in the left and right images
    #   rect_l, rect_r: left and right min-area rects, results of cv2.minAreaRect()
    #   c_l, c_r: left and right contours, results of cv2.findContours()
    # horizontal alignment limit: the two rect centers must lie on (almost) the same image row,
    # their vertical difference must not exceed 10 pixels
    if abs(rect_l[0][1] - rect_r[0][1]) > 10:
        logger.info('not horizontally aligned')
        #print('not horizontally aligned\n')
        return None
    # rotation angle limit: the rotation angles of the left and right min-area rects
    # must not differ by more than 10 degrees
    if abs(rect_l[2] - rect_r[2]) > 10:
        logger.info('rotation angles do not match')
        #print('rotation angles do not match\n')
        return None
    # approximate the contours with polygons
    epsilon_l = rect_l[1][0]/2 if rect_l[1][0] > rect_l[1][1] else rect_l[1][1]/2
    epsilon_r = rect_r[1][0]/2 if rect_r[1][0] > rect_r[1][1] else rect_r[1][1]/2
    # epsilon is chosen so that approxPolyDP stops at its first step, i.e. keeps only the two farthest points
    approx_l = cv2.approxPolyDP(c_l, epsilon_l, True)
    approx_r = cv2.approxPolyDP(c_r, epsilon_r, True)
    # if the approximation does not give exactly two points, return None
    if approx_l.shape[0] != 2 or approx_r.shape[0] != 2:
        return None
    points = []
    if approx_l[0][0][0] < approx_l[1][0][0]:
        points.append(approx_l[0][0])
        points.append(approx_l[1][0])
    else:
        points.append(approx_l[1][0])
        points.append(approx_l[0][0])
    if approx_r[0][0][0] < approx_r[1][0][0]:
        points.append(approx_r[0][0])
        points.append(approx_r[1][0])
    else:
        points.append(approx_r[1][0])
        points.append(approx_r[0][0])
    return points
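
# Endpoint-ordering note: the returned list is [left_p1, left_p2, right_p1, right_p2], with each
# image's two endpoints sorted left-to-right by x, so left_p1 pairs with right_p1 and left_p2 with
# right_p2 when distance_measure() is later called on them.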


def contours_match_new(contours_L, contours_R, classes_list_L, img_L_name='', img_R_name=''):
    # find corresponding endpoints in the left and right images in order to compute fish lengths
    #   contours_L, contours_R: left and right contours (whole-fish contours)
    #   classes_list_L: classes of the whole fish
    # point pairs
    points = []
    # fish classes of the left contours that matched successfully
    classes_list_new = []
    # contours_L_matched is a subset of contours_L
    contours_L_matched = []
    # loop over every contour in the left image and try to match it in the right image
    for j, c_l in enumerate(contours_L):
        # best match score, smaller is better
        best_match = 100
        # index of the best match in the right image
        best_match_index = 0
        # whether a best match exists
        best_match_exist = False
        # compute the match score against every right contour
        for i, c_r in enumerate(contours_R):
            match = cv2.matchShapes(c_l, c_r, 1, 0.0)
            #print('similarity:', j, i, match)
            # method 2: set a threshold of 0.5; if more than one similarity is below it, pick the best one.
            # method 1: set a threshold of 0.3.
            if match < 0.15 and match < best_match:
                # probably the same fish
                best_match_index = i
                best_match = match
                best_match_exist = True
        if not best_match_exist:
            # no sufficiently similar shape exists
            logger.info('index {} in contours_L is not matched'.format(j))
            #print('index {} in contours_L is not matched'.format(j))
            continue
        # best-matching contour in the right image
        c_r = contours_R[best_match_index]
        # min-area rects
        rect_l = cv2.minAreaRect(c_l)
        rect_r = cv2.minAreaRect(c_r)
        p = get_point_pair(rect_l, c_l, rect_r, c_r)
        if p is None:
            continue
        points.append(p)
        classes_list_new.append(classes_list_L[j])
        contours_L_matched.append(c_l)
    # points is like: [[p1,p2,p3,p4],...]
    # classes_list_new corresponds to points, they have the same length
    # contours_L_matched (a subset of contours_L) also has the same length as points
    return points, classes_list_new, contours_L_matched
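
# Matching note: cv2.matchShapes(..., 1, 0.0) uses the CONTOURS_MATCH_I1 metric based on log-scaled
# Hu moments; 0 means identical shapes and larger values mean less similar, so a pair is accepted
# only when the score is below 0.15 and the lowest-scoring right contour is kept for each left contour.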


def get_length(points):
    # compute the fish lengths
    # points: [[(),(),(),()],...]
    # lengths of every matched fish
    fish_length = []
    for i, p in enumerate(points):
        # compute the fish length
        fish_len = distance_measure(p[0], p[1], p[2], p[3])
        # a length above 650 mm is treated as a wrong value
        if fish_len > 650:
            logger.warning('fish_len more than 650mm!')
            continue
        fish_length.append(round(fish_len, 1))
    return fish_length


def get_weight(fish_length):
    # estimate fish weight from fish length using W = a*L^b, with L in cm and W in g
    # (current fit: a = 0.01807, b = 3.1106)
    #   fish_length: list of fish lengths in mm
    fish_weight = []
    for length in fish_length:
        # convert length from mm to cm before applying the formula
        #weight = 0.0149*pow(length/10, 3.0625)
        weight = 0.01807*pow(length/10, 3.1106)
        # weight is in grams
        fish_weight.append(round(weight, 1))
    return fish_weight
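
# Rough worked example (illustrative numbers only): a fish measured at 250 mm is converted to 25 cm,
# so its estimated weight is 0.01807 * 25 ** 3.1106 ≈ 403 g.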


def save_image(image_L, image_R, contours_L_matched, current_time):
    # image_L: left-camera image
    # contours_L_matched: successfully matched contours in the left image, a subset of contours_L
    # current_time: current time string
    # temporarily commented out so that some original, unannotated images are saved
    #for i, contour in enumerate(contours_L_matched):
    #    cv2.drawContours(image_L, [contour], -1, (0, 0, 255), 2)
    #    x, y, w, h = cv2.boundingRect(contour)
    #    cv2.putText(image_L, str(i+1), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2)
    # change '2021-10-09 12:11:10' to '2021-10-09-12-11-10'
    current_time = current_time.replace(' ', '-').replace(':', '-')
    cv2.imwrite(os.path.join(IMAGE_PATH, current_time+'.jpg'), image_L)
    return image_L


def put_image(image, species, length, weight, Q_masked_img):
    # push the image and the recognized fish information to the front end
    #   image: image with the fish contours drawn on it
    #   species: list of fish species
    #   length: list of fish lengths, in mm
    #   weight: list of fish weights, in g
    #   Q_masked_img: queue to the websocket process
    json_list = []
    for i, s in enumerate(species):
        item = {}
        item['fish_id'] = i+1
        item['species'] = s
        item['size'] = length[i]
        item['weight'] = weight[i]  #round(random.random()*10,1)
        json_list.append(item)
    try:
        print('json realtime:', json.dumps(json_list))
        Q_masked_img.put_nowait([image, json.dumps(json_list)])
    except Exception as e:
        logger.warning(str(e))
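
# The JSON string pushed to the front end looks like (illustrative values):
# [{"fish_id": 1, "species": "crucian", "size": 185.1, "weight": 403.2}, ...]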


def get_species(classes_list_L_matched):
    # map class indices to species names
    # classes_list_L_matched: classes of the matched whole-fish contours, like [1, 0, 2, 1, 0, 3]
    fish_species = []
    for i in classes_list_L_matched:
        fish_species.append(FISH_SPECIES_MAP[i])
    return fish_species


def construct_result(fish_species, fish_length, fish_weight):
    # build the rows to be saved to MySQL
    #   fish_species: ['crucian', 'sebastes', 'crucian', ...]
    #   fish_length: [25.22, 18.35, 23.56, ...]
    #   fish_weight: [199.2, 298.3, 326.6, ...]
    result = []
    item = []
    for i, length in enumerate(fish_length):
        item.append(i+1)  # fish_id
        item.append(fish_species[i])
        item.append(length)
        item.append(fish_weight[i])  #(random.random()*10)#fish_weight
        result.append(item.copy())
        item.clear()
    return result


def insert_mysql_new(result, current_time):
    # result: output of construct_result(), rows of [fish_id, fish_species, fish_size, fish_weight]
    # the table columns are: datetime, binocular_id, image_path, fish_id, fish_species, fish_size, fish_weight, species_add;
    # species_add is edited on the web side. Each row supplies four values, the remaining values are shared.
    sql = "INSERT INTO `binocular_data` (`datetime`,`binocular_id`,`fish_id`,`fish_species`,`fish_size`,`fish_weight`) VALUES (%s,%s,%s,%s,%s,%s)"
    for r in result:
        success = my.insert(sql, (current_time, BINOCULAR_ID, r[0], r[1], r[2], r[3]))
        if not success:
            logger.warning('insert to mysql failed')
            #print('insert to mysql failed')
        success = my_local.insert(sql, (current_time, BINOCULAR_ID, r[0], r[1], r[2], r[3]))
        if not success:
            logger.warning('insert to local mysql failed')
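
# For reference, a table definition consistent with this INSERT might look roughly like the sketch
# below (column types are assumptions, not taken from the real schema):
# CREATE TABLE `binocular_data` (
#     `datetime`     DATETIME,
#     `binocular_id` INT,
#     `image_path`   VARCHAR(255),
#     `fish_id`      INT,
#     `fish_species` VARCHAR(64),
#     `fish_size`    FLOAT,
#     `fish_weight`  FLOAT,
#     `species_add`  VARCHAR(64)
# );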


def maskrcnn_init():
    global predictor_maskrcnn
    cfg_maskrcnn = get_cfg()
    cfg_maskrcnn.merge_from_file("./configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    cfg_maskrcnn.OUTPUT_DIR = './weights'
    cfg_maskrcnn.MODEL.WEIGHTS = os.path.join(cfg_maskrcnn.OUTPUT_DIR, 'model_0008749good_maskrcnn.pth')
    cfg_maskrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # testing threshold for this model
    cfg_maskrcnn.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg_maskrcnn.MODEL.ROI_HEADS.NUM_CLASSES = 2
    # run Mask R-CNN on the CPU (the Faster R-CNN above keeps the default device)
    cfg_maskrcnn.MODEL.DEVICE = 'cpu'
    predictor_maskrcnn = DefaultPredictor(cfg_maskrcnn)


def predict_new(Q_binocular, Q_masked_img):
    # read images from files, for testing only
    #image_L_list = sorted(os.listdir('./dataset/faster_rcnn/images'))
    #image_R_list = sorted(os.listdir('./dataset/images_R'))
    # counter used to empty the CUDA cache periodically
    i = 0
    while True:
        # read the left and right camera frames
        try:
            image_L, image_R = Q_binocular.get_nowait()
        except Exception as e:
            logger.warning('no image in queue for binocular')
            #print('no image in queue for binocular')
            time.sleep(2)
            continue
        image_L = BGR_stretching(image_L)
        image_R = BGR_stretching(image_R)
        # instance segmentation
        maskrcnn_output_L = predictor_maskrcnn(image_L)
        maskrcnn_output_R = predictor_maskrcnn(image_R)
        if len(maskrcnn_output_L['instances']) < 1 or len(maskrcnn_output_R['instances']) < 1:
            logger.info('no instance in image!')
            continue
        # completeness filter; confident_list_L is like [1,0,1,1,0,0] and has the same length
        # as maskrcnn_output_L
        confident_list_L = whole_filter_new(image_L, maskrcnn_output_L)
        confident_list_R = whole_filter_new(image_R, maskrcnn_output_R)
        if confident_list_L is None or confident_list_R is None:
            # no whole fish
            logger.info('no whole fish in image!')
            continue
        # get the contours of the whole fish
        contours_L = get_contours_new(confident_list_L, maskrcnn_output_L)
        contours_R = get_contours_new(confident_list_R, maskrcnn_output_R)
        if len(contours_L) < 1 or len(contours_R) < 1:
            logger.info('no contours to match')
            continue
        if len(contours_L) > confident_list_L.count(1) or len(contours_R) > confident_list_R.count(1):
            logger.info('contours num not match')
            continue
        # get the classes of the whole fish; only the left image is needed, because the left contours are
        # looped over to match the right contours; classes_list_L has the same length as contours_L, and
        # classes_list_L[i] is the class index of contours_L[i]
        classes_list_L = get_classes_list(confident_list_L, maskrcnn_output_L)
        # with the left image as anchor, loop-match the right contours;
        # points are the points to measure, classes_list_L_matched are the final class indices,
        # contours_L_matched is the subset of contours_L that matched successfully between left and right
        points, classes_list_L_matched, contours_L_matched = contours_match_new(contours_L, contours_R, classes_list_L)
        #print('points: ', points)
        if len(points) < 1:
            logger.info('no match contours')
            #print('no match contours')
            continue
        # compute the lengths
        # fish_length is like [185.11, 192.99, 255.01] and has the same length as points;
        # if a detected fish is longer than 650 mm, fish_length ends up shorter than points
        fish_length = get_length(points)
        if len(fish_length) != len(points):
            logger.warning('fish length is more than 650mm')
            continue
        logger.info(str(i)+'fish_length:'+str(fish_length))
        #print('length: ', fish_length)
        fish_weight = get_weight(fish_length)
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-5]
        # save image_L with a name like '2021-10-10-12-10-11.jpg' and draw the contours and fish ids.
        # a crontab job deletes old images, see /home/sencott/delete_history_image.sh
        image_with_contours = save_image(image_L, image_R, contours_L_matched, current_time)
        #cv2.imwrite('masked.jpg', image_with_contours)
        # get the fish species
        fish_species = get_species(classes_list_L_matched)
        # put the image with drawn contours and the fish info into the queue
        put_image(image_with_contours, fish_species, fish_length, fish_weight, Q_masked_img)
        # build the result and save it to MySQL
        result = construct_result(fish_species, fish_length, fish_weight)
        insert_mysql_new(result, current_time)
        # empty the CUDA cache every 1000 images
        i += 1
        if i == 1000:
            torch.cuda.empty_cache()
            i = 0
        time.sleep(2)
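
# Pipeline summary for predict_new(): grab a stereo pair from the queue -> contrast-stretch both
# frames -> Mask R-CNN instance segmentation -> Faster R-CNN head/tail check to keep only whole fish
# -> extract mask contours -> match left/right contours by shape -> pick the two farthest contour
# points per side -> triangulate the length via the Q matrix -> convert length to weight -> push to
# the web front end and store in MySQL.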


logger = logger_init(LOG_LEVEL)
fasterrcnn_init()
maskrcnn_init()
# two queues: one for images pushed to the front end, one for images fed to the binocular pipeline
queue_camera_to_web = Queue(10)
queue_camera_to_net = Queue(10)
# queue for images after recognition, used for communication between predict_new and the websocket_server process
queue_masked_img = Queue(10)
# holds the camera setup parameters
queue_param = Queue(1)
my = mysql(MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE, MYSQL_PORT)
my_local = mysql(MYSQL_HOST_LOCAL, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE, MYSQL_PORT)

if __name__ == '__main__':
    process1 = Process(target=img_capture, args=(queue_camera_to_web, queue_camera_to_net, queue_param))
    process1.daemon = True
    process1.start()
    process2 = Process(target=push_to_web, args=(queue_camera_to_web, queue_masked_img, queue_param))
    process2.daemon = True
    process2.start()
    predict_new(queue_camera_to_net, queue_masked_img)
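
# Running it (sketch, assuming the file name shown in the header): `python train_and_predict_5.py`
# with config.ini, camera_configs.py, and the two Detectron2 weight files under ./weights in place;
# the capture and websocket processes are started as daemons and the main process runs predict_new().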