123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488 |
- #-*- coding:utf-8 -*
- #You may need to restart your runtime prior to this, to let your installation take effect
- # Some basic setup
- import cv2
- from detectron2.engine import DefaultPredictor
- from detectron2.config import get_cfg
- import os
- import time
- from datetime import datetime
- import torch
- import numpy as np
- import logging
- import logging.handlers
- import camera_configs
- from mysql import mysql
- from multiprocessing import Process, Queue
- from img_preprocess import BGR_stretching
- from video_capture import img_capture, LOG_LEVEL, BINOCULAR_ID
- from websocket_server import push_to_web
- import json
- import ast
- from configparser import ConfigParser
config = ConfigParser()
config.read(r'./config.ini',encoding='utf-8')
# Fish species lookup, indexed by the class id predicted by the detector
# (parsed from a Python literal stored in config.ini).
FISH_SPECIES_MAP = ast.literal_eval(config['species']['FISH_SPECIES_MAP'])
# Directory where detected images are saved.
IMAGE_PATH = config['path']['IMAGE_PATH']
# makedirs(exist_ok=True) also creates missing parent directories and cannot race
# with another process creating the directory between a check and the create.
os.makedirs(IMAGE_PATH, exist_ok=True)
# MySQL connection settings (remote server).
MYSQL_HOST = config['mysql']['MYSQL_HOST']
MYSQL_PORT = int(config['mysql']['MYSQL_PORT'])
MYSQL_USER = config['mysql']['MYSQL_USER']
MYSQL_PASSWORD = config['mysql']['MYSQL_PASSWORD']
MYSQL_DATABASE = config['mysql']['MYSQL_DATABASE']
# Local mirror database; reuses the user/password/database/port settings above.
MYSQL_HOST_LOCAL = '127.0.0.1'
def logger_init(LOG_LEVEL):
    """Configure and return the application logger named 'mylogger'.

    Records are written to 'master.log' in the current working directory via a
    RotatingFileHandler: when the file reaches 1 MB it rolls over, keeping up to
    5 backups (master.log.1 is the newest, master.log.5 the oldest).

    Args:
        LOG_LEVEL: level to set on the logger (e.g. logging.INFO).

    Returns:
        The configured logger. Calling this again only updates the level; it
        does not attach a second file handler.
    """
    logger = logging.getLogger('mylogger')
    logger.setLevel(LOG_LEVEL)
    # Without this guard every extra call would add another handler, and each
    # record would then be written to master.log once per handler.
    if not any(isinstance(h, logging.handlers.RotatingFileHandler) for h in logger.handlers):
        formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
        file_handler = logging.handlers.RotatingFileHandler('master.log', mode='a', maxBytes=1024 * 1024, backupCount=5)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
def fasterrcnn_init():
    """Build the Faster R-CNN head/tail detector and store it in the global ``predictor_fasterrcnn``.

    Uses the COCO faster_rcnn_R_50_FPN_3x config with the weights file
    ./weights/model_final_fasterrcnn.pth, 2 classes and a test score threshold
    of 0.8 (raise or lower it to change the detection threshold).
    MODEL.DEVICE is left at the config default; maskrcnn_init, by contrast,
    forces the CPU.
    """
    global predictor_fasterrcnn
    cfg = get_cfg()
    cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
    cfg.OUTPUT_DIR = './weights'
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_final_fasterrcnn.pth')
    roi_heads = cfg.MODEL.ROI_HEADS
    roi_heads.SCORE_THRESH_TEST = 0.8
    roi_heads.BATCH_SIZE_PER_IMAGE = 128
    roi_heads.NUM_CLASSES = 2
    predictor_fasterrcnn = DefaultPredictor(cfg)
def whole_filter_new(image, insseg_output):
    """Flag which Mask R-CNN instances contain a whole fish (both head and tail).

    Runs the global head/tail detector ``predictor_fasterrcnn`` (created by
    fasterrcnn_init) on ``image``. Class 0 boxes are treated as heads and all
    other classes as tails. A Mask R-CNN box counts as a whole fish when it
    contains the centre of at least one head box and at least one tail box.

    Args:
        image: the image that was also fed to Mask R-CNN.
        insseg_output: Mask R-CNN predictor output for ``image``.

    Returns:
        A list of 1/0 flags, one per entry of ``insseg_output['instances']``
        (1 = whole fish). Returns None when the head/tail detector finds nothing,
        or when no instance is a whole fish.
    """
    boxes = insseg_output['instances'].pred_boxes.tensor.clone().detach().cpu().numpy()
    parts = predictor_fasterrcnn(image)['instances']
    if len(parts) == 0:
        return None
    # Move the detector output to the host once, instead of one device sync per element.
    part_boxes = parts.pred_boxes.tensor.detach().cpu().numpy()
    part_classes = parts.pred_classes.detach().cpu().numpy()
    heads_center = []
    tails_center = []
    for (x0, y0, x1, y1), cls in zip(part_boxes, part_classes):
        center = (float((x0 + x1) / 2), float((y0 + y1) / 2))
        if cls == 0:
            heads_center.append(center)
        else:
            tails_center.append(center)
    confident_list = [
        1 if contains_whole_fish(box, heads_center, tails_center) else 0
        for box in boxes
    ]
    if 1 not in confident_list:
        return None
    return confident_list
def contains_whole_fish(box, heads_center, tails_center):
    """Return True when ``box`` contains at least one head centre and at least one tail centre.

    Args:
        box: Mask R-CNN box as [x0, y0, x1, y1].
        heads_center, tails_center: lists of (x, y) centre points.
    """
    if not contains_point(box, heads_center):
        return False
    return contains_point(box, tails_center)
def contains_point(box, point_list):
    """Return True if any point of ``point_list`` lies inside ``box`` (edges inclusive).

    Args:
        box: [x0, y0, x1, y1]; the test assumes x0 <= x1 and y0 <= y1.
        point_list: iterable of (x, y) points.
    """
    return any(
        box[0] <= pt[0] <= box[2] and box[1] <= pt[1] <= box[3]
        for pt in point_list
    )
def get_classes_list(confident_list, output):
    """Return the predicted class id of every instance flagged as a whole fish.

    The result lists the classes of the flagged instances in order. Contours are
    later tracked per whole fish, so this list maps each contour to its class.

    Args:
        confident_list: 1/0 flags from whole_filter_new(), aligned with output['instances'].
        output: Mask R-CNN predictor output.

    Returns:
        A list of Python ints, one per entry of ``confident_list`` equal to 1.
    """
    return [
        output['instances'].pred_classes[idx].clone().detach().cpu().item()
        for idx, flag in enumerate(confident_list)
        if flag == 1
    ]
def get_contours_new(confident_list, output):
    """Extract the external contours of the whole-fish masks in a Mask R-CNN output.

    Args:
        confident_list: 1/0 flags from whole_filter_new(), aligned with output['instances'].
        output: Mask R-CNN predictor output.

    Returns:
        A flat list of contours (cv2.findContours format) for all flagged masks,
        or None when there are no instances or no flags. A single mask can
        contribute zero or several contours, so the list is not guaranteed to
        line up one-to-one with the flagged instances.
    """
    if len(output['instances']) < 1 or len(confident_list) < 1:
        return None
    contours_total = []
    for i, confident in enumerate(confident_list):
        if confident != 1:
            continue
        binary_img = output['instances'].pred_masks[i].clone().detach().cpu().numpy().astype(np.uint8) * 255
        # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
        # (contours, hierarchy); [-2] selects the contours in every version.
        contours = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        contours_total.extend(contours)
    return contours_total
def distance_measure(left_point1, left_point2, right_point1, right_point2):
    """Return the 3-D distance between two end points seen by both cameras.

    Each end point is reprojected to 3-D with cv2.perspectiveTransform and the
    reprojection matrix camera_configs.Q. The input is
    (x_left, y_left, x_left - x_right), i.e. the left pixel plus its disparity.
    The result is in the units of the Q calibration (callers treat it as mm).
    """
    def to_xyz(left, right):
        # Reproject one left/right correspondence to a 3-D point.
        disparity = left[0] - right[0]
        pixel = np.array([[[left[0], left[1], disparity]]], dtype=np.float32)
        return cv2.perspectiveTransform(pixel, camera_configs.Q)[0][0]

    start = to_xyz(left_point1, right_point1)
    end = to_xyz(left_point2, right_point2)
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    dz = end[2] - start[2]
    return (dx ** 2 + dy ** 2 + dz ** 2) ** 0.5
def _long_axis_angle(rect):
    """Return the orientation, in degrees within [0, 180), of the long side of a cv2.minAreaRect result."""
    _, (width, height), angle = rect
    if width < height:
        angle += 90
    return angle % 180


def get_point_pair(rect_l, c_l, rect_r, c_r):
    """Find the two end points of one fish in both the left and right image.

    Args:
        rect_l, rect_r: cv2.minAreaRect() results for the left/right contour.
        c_l, c_r: the matching left/right contours (cv2.findContours format).

    Returns:
        [left_a, left_b, right_a, right_b], with each image's pair ordered by x.
        Returns None when any of these holds:
        - the rectangle centres differ by more than 10 px vertically (not row-aligned);
        - their long axes differ by more than 10 degrees;
        - either contour does not reduce to exactly two polygon vertices.
    """
    # Row alignment: the same fish should lie on (nearly) the same image row in both views.
    if abs(rect_l[0][1] - rect_r[0][1]) > 10:
        logger.info('未水平对齐')
        return None
    # Rotation check. cv2.minAreaRect can describe the same rectangle with
    # width/height swapped and the angle shifted by 90 degrees, so raw angles
    # such as 1 and 89 may describe long axes only 2 degrees apart. Compare the
    # long-axis orientations modulo 180 instead.
    angle_diff = abs(_long_axis_angle(rect_l) - _long_axis_angle(rect_r))
    if min(angle_diff, 180 - angle_diff) > 10:
        logger.info('旋转角度不匹配')
        return None
    # epsilon = half of the long side, intended to collapse the contour to its
    # two extreme end points; any other vertex count is rejected below.
    approx_l = cv2.approxPolyDP(c_l, max(rect_l[1]) / 2, True)
    approx_r = cv2.approxPolyDP(c_r, max(rect_r[1]) / 2, True)
    if approx_l.shape[0] != 2 or approx_r.shape[0] != 2:
        return None
    points = []
    for approx in (approx_l, approx_r):
        first, second = approx[0][0], approx[1][0]
        # Left-most vertex first; on equal x the second vertex goes first.
        points.extend((first, second) if first[0] < second[0] else (second, first))
    return points
def contours_match_new(contours_L, contours_R, classes_list_L, img_L_name='', img_R_name=''):
    """Pair left-image fish contours with right-image contours and collect measurement points.

    For each left contour, pick the right contour with the lowest
    cv2.matchShapes score (method 1, smaller is more similar), provided the
    score is strictly below 0.15. Ties go to the lower right index. The pair is
    then passed to get_point_pair(); pairs it rejects are dropped.

    NOTE(review): right contours are not reserved once used, so two left
    contours can be paired with the same right contour.

    Args:
        contours_L, contours_R: whole-fish contours of the left/right image.
        classes_list_L: class id of each entry of contours_L (same order).
        img_L_name, img_R_name: unused; kept for call compatibility.

    Returns:
        (points, classes_list_new, contours_L_matched), three lists of equal length:
        - points: the four end points [left_a, left_b, right_a, right_b] of each matched fish;
        - classes_list_new: that fish's class id;
        - contours_L_matched: its left contour, a subset of contours_L in the original order.
    """
    points = []
    classes_list_new = []
    contours_L_matched = []
    for left_idx, c_l in enumerate(contours_L):
        scores = [cv2.matchShapes(c_l, c_r, 1, 0.0) for c_r in contours_R]
        candidates = [(score, right_idx) for right_idx, score in enumerate(scores) if score < 0.15]
        if not candidates:
            logger.info('index {} in contours_L is not matched'.format(left_idx))
            continue
        _, best_idx = min(candidates)
        c_r = contours_R[best_idx]
        pair = get_point_pair(cv2.minAreaRect(c_l), c_l, cv2.minAreaRect(c_r), c_r)
        if pair is None:
            continue
        points.append(pair)
        classes_list_new.append(classes_list_L[left_idx])
        contours_L_matched.append(c_l)
    return points, classes_list_new, contours_L_matched
def get_length(points):
    """Compute fish lengths (mm) from matched stereo end-point quadruples.

    Args:
        points: list of [left_a, left_b, right_a, right_b] from contours_match_new().

    Returns:
        Lengths rounded to 0.1, as plain Python floats. A length above 650 mm is
        treated as a measurement error: a warning is logged and the entry is
        skipped, so the result can be shorter than ``points``.
    """
    fish_length = []
    for p in points:
        fish_len = distance_measure(p[0], p[1], p[2], p[3])
        if fish_len > 650:
            logger.warning('fish_len more than 650mm!')
            continue
        # distance_measure computes on float32 arrays, so fish_len is typically a
        # numpy float32, which json.dumps (used in put_image) cannot serialize;
        # convert to a Python float first.
        fish_length.append(round(float(fish_len), 1))
    return fish_length
def get_weight(fish_length, *, a=0.01807, b=3.1106):
    """Estimate fish weights from lengths with the power law W = a * L**b.

    L is the length in cm (inputs are converted from mm) and W is the weight in grams.

    Args:
        fish_length: fish lengths in mm.
        a, b: length-weight coefficients. The defaults are the values this
            function has always used; pass others to fit a different population.

    Returns:
        Weights in g rounded to 0.1, as plain Python floats.
    """
    # float() turns numpy scalars (e.g. float32 lengths) into Python floats so
    # the weights stay JSON-serializable.
    return [round(a * pow(float(length) / 10, b), 1) for length in fish_length]
def save_image(image_L, image_R, contours_L_matched, current_time):
    """Save the left image as a JPEG in IMAGE_PATH and return it.

    Args:
        image_L: left camera image.
        image_R: right camera image (currently unused).
        contours_L_matched: left contours that matched successfully, a subset of
            contours_L. They are currently unused because drawing is disabled.
        current_time: timestamp string such as '2021-10-09 12:11:10.1'. Spaces
            and colons are replaced by '-' to form the file name.

    Returns:
        ``image_L`` (unmodified while contour drawing is disabled).
    """
    # Drawing is temporarily disabled so that raw images are saved:
    # for i, contour in enumerate(contours_L_matched):
    #     cv2.drawContours(image_L, [contour], -1, (0, 0, 255), 2)
    #     x, y, w, h = cv2.boundingRect(contour)
    #     cv2.putText(image_L, str(i + 1), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2)
    file_stem = current_time.replace(' ', '-').replace(':', '-')
    file_path = os.path.join(IMAGE_PATH, file_stem + '.jpg')
    # cv2.imwrite reports failure (bad directory, unsupported path, full disk)
    # through its return value rather than an exception.
    if not cv2.imwrite(file_path, image_L):
        logger.warning('failed to write image {}'.format(file_path))
    return image_L
def put_image(image, species, length, weight, Q_masked_img):
    """Push an image and per-fish information to the front-end queue.

    Args:
        image: image to display (as returned by save_image()).
        species: species name of each fish.
        length: length of each fish, in mm.
        weight: weight of each fish, in g.
        Q_masked_img: queue consumed by the web-push process.

    The queued item is [image, json_string], where json_string encodes a list of
    {"fish_id", "species", "size", "weight"} objects with fish_id starting at 1.
    This is best effort: if serialization fails or the queue is full, a warning
    is logged and the frame is dropped, so the measurement loop never blocks.
    """
    json_list = [
        {'fish_id': i + 1, 'species': s, 'size': length[i], 'weight': weight[i]}
        for i, s in enumerate(species)
    ]
    try:
        payload = json.dumps(json_list)
        logger.debug('json realtime: %s', payload)
        Q_masked_img.put_nowait([image, payload])
    except Exception as e:
        logger.warning(str(e))
def get_species(classes_list_L_matched):
    """Translate class ids into species names via FISH_SPECIES_MAP.

    Args:
        classes_list_L_matched: class ids of the matched whole fish, e.g. [1, 0, 2, 1, 0, 3].

    Returns:
        The species names, in the same order.
    """
    return [FISH_SPECIES_MAP[class_id] for class_id in classes_list_L_matched]
- def construct_result(fish_species,fish_length,fish_weight):
- #construct result save to mysql
- #fish_species: ['crucian', 'sebastes', 'crucian',...]
- #fish_length: [25.22, 18.35, 23.56, ...]
- #fish_weight: [199.2, 298.3, 326.6, ...]
- result = []
- item = []
- for i,length in enumerate(fish_length):
- item.append(i+1)#fish_id
- item.append(fish_species[i])
- item.append(length)
- item.append(fish_weight[i])#(random.random()*10)#fish_weight
- result.append(item.copy())
- item.clear()
- return result
def insert_mysql_new(result, current_time):
    """Insert one row per fish into both the remote and the local `binocular_data` tables.

    Args:
        result: rows from construct_result(), each [fish_id, fish_species, fish_size, fish_weight].
        current_time: timestamp string stored in the `datetime` column.

    The timestamp and BINOCULAR_ID are shared by every row. Each failed insert
    is logged as a warning and the remaining rows are still attempted.
    """
    sql = "INSERT INTO `binocular_data` (`datetime`,`binocular_id`,`fish_id`,`fish_species`,`fish_size`,`fish_weight`) VALUES (%s,%s,%s,%s,%s,%s)"
    for row in result:
        params = (current_time, BINOCULAR_ID, row[0], row[1], row[2], row[3])
        if not my.insert(sql, params):
            logger.warning('insert to mysql failed')
        if not my_local.insert(sql, params):
            logger.warning('insert to local mysql failed')
def maskrcnn_init():
    """Build the Mask R-CNN fish segmenter and store it in the global ``predictor_maskrcnn``.

    Uses the COCO mask_rcnn_R_50_FPN_3x config with the weights file
    ./weights/model_0008749good_maskrcnn.pth, 2 classes and a test score
    threshold of 0.5, and forces inference onto the CPU.
    """
    global predictor_maskrcnn
    cfg = get_cfg()
    cfg.merge_from_file("./configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    cfg.OUTPUT_DIR = './weights'
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, 'model_0008749good_maskrcnn.pth')
    roi_heads = cfg.MODEL.ROI_HEADS
    roi_heads.SCORE_THRESH_TEST = 0.5
    roi_heads.BATCH_SIZE_PER_IMAGE = 128
    roi_heads.NUM_CLASSES = 2
    cfg.MODEL.DEVICE = 'cpu'
    predictor_maskrcnn = DefaultPredictor(cfg)
def _process_stereo_pair(image_L, image_R, Q_masked_img, frame_count):
    """Run segmentation, stereo matching and measurement on one left/right image pair.

    On success this function:
    - saves the left image;
    - pushes the results to Q_masked_img;
    - inserts them into MySQL.

    Returns True when results were produced and False when the pair was
    skipped (the reason is logged).
    """
    image_L = BGR_stretching(image_L)
    image_R = BGR_stretching(image_R)
    # Instance segmentation.
    maskrcnn_output_L = predictor_maskrcnn(image_L)
    maskrcnn_output_R = predictor_maskrcnn(image_R)
    if len(maskrcnn_output_L['instances']) < 1 or len(maskrcnn_output_R['instances']) < 1:
        logger.info('no instance in image!')
        return False
    # Whole-fish filter: 1/0 flags aligned with the Mask R-CNN instances, or None.
    confident_list_L = whole_filter_new(image_L, maskrcnn_output_L)
    confident_list_R = whole_filter_new(image_R, maskrcnn_output_R)
    if confident_list_L is None or confident_list_R is None:
        logger.info('no whole fish in image!')
        return False
    # Contours of the whole fish; get_contours_new can return None as well as [].
    contours_L = get_contours_new(confident_list_L, maskrcnn_output_L)
    contours_R = get_contours_new(confident_list_R, maskrcnn_output_R)
    if not contours_L or not contours_R:
        logger.info('no contours to match')
        return False
    # Left classes are assigned by position (one per whole-fish mask), so any
    # left mask yielding zero or several contours would misalign contours and
    # classes: require an exact count on the left. Right contours only serve as
    # match candidates, so the right side only rejects extra fragments.
    # NOTE(review): one mask with 0 contours plus another with 2 still passes.
    if len(contours_L) != confident_list_L.count(1) or len(contours_R) > confident_list_R.count(1):
        logger.info('contours num not match')
        return False
    # Classes of the whole fish; the left image is enough, since it drives the
    # matching. classes_list_L[i] is the class of contours_L[i].
    classes_list_L = get_classes_list(confident_list_L, maskrcnn_output_L)
    # Use the left image as the anchor and match its contours against the right image.
    # points: end points to measure; classes_list_L_matched: their classes;
    # contours_L_matched: the left contours that matched (a subset of contours_L).
    points, classes_list_L_matched, contours_L_matched = contours_match_new(contours_L, contours_R, classes_list_L)
    if len(points) < 1:
        logger.info('no match contours')
        return False
    # Lengths, e.g. [185.1, 193.0, 255.0], one per entry of points; get_length
    # drops lengths above 650 mm, which makes the lists differ in size.
    fish_length = get_length(points)
    if len(fish_length) != len(points):
        logger.warning('fish lenght is more than 650mm')
        return False
    logger.info(str(frame_count) + 'fish_length:' + str(fish_length))
    fish_weight = get_weight(fish_length)
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-5]
    # Save image_L under a timestamp-derived name. The comment this code
    # replaced refers to an external cron job
    # (/home/sencott/delete_history_image.sh) that deletes old images.
    image_with_contours = save_image(image_L, image_R, contours_L_matched, current_time)
    fish_species = get_species(classes_list_L_matched)
    # Push the image and fish info to the front end, then persist to MySQL.
    put_image(image_with_contours, fish_species, fish_length, fish_weight, Q_masked_img)
    result = construct_result(fish_species, fish_length, fish_weight)
    insert_mysql_new(result, current_time)
    return True


def predict_new(Q_binocular, Q_masked_img):
    """Consume stereo image pairs forever, measuring and publishing every pair that yields results.

    Args:
        Q_binocular: queue of (image_L, image_R) pairs; an empty queue is polled every 2 s.
        Q_masked_img: queue that receives [image, json] items for the front end.

    After each successful pair the loop sleeps 2 s. The CUDA cache is emptied
    after every 1000 successful pairs. A failure while processing one pair is
    logged with its traceback and the loop moves on to the next pair.
    """
    import queue

    frame_count = 0
    while True:
        try:
            frame = Q_binocular.get_nowait()
        except queue.Empty:
            logger.warning('no image in queue for binocular')
            time.sleep(2)
            continue
        try:
            image_L, image_R = frame
            processed = _process_stereo_pair(image_L, image_R, Q_masked_img, frame_count)
        except Exception:
            # Top-level boundary of a long-running loop: one bad pair must not stop the service.
            logger.exception('failed to process binocular frame')
            continue
        if not processed:
            continue
        frame_count += 1
        if frame_count == 1000:
            torch.cuda.empty_cache()
            frame_count = 0
        time.sleep(2)
# Module-level initialization: logger, both detectors, inter-process queues and DB connections.
# NOTE(review): this runs at import time. With the 'spawn'/'forkserver' start
# methods, child processes re-import the main module and would repeat it
# (including model loading); consider moving it under the __main__ guard.
logger = logger_init(LOG_LEVEL)
fasterrcnn_init()
maskrcnn_init()
# Two queues: images pushed to the web front end, and image pairs for binocular processing.
queue_camera_to_web = Queue(10)
queue_camera_to_net= Queue(10)
# Queue of processed images, used between predict_new (run in the main process
# below) and the push_to_web process.
queue_masked_img = Queue(10)
# Holds camera setting parameters.
queue_param = Queue(1)
# Remote and local MySQL connections; insert_mysql_new writes every row to both.
my = mysql(MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE, MYSQL_PORT)
my_local = mysql(MYSQL_HOST_LOCAL, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE, MYSQL_PORT)
if __name__=='__main__':
    # Camera capture process: feeds both the web queue and the processing queue.
    process1 = Process(target=img_capture,args=(queue_camera_to_web,queue_camera_to_net,queue_param))
    # Daemon children are terminated when the main process exits.
    process1.daemon=True
    process1.start()
    # Web push process: consumes camera frames and processed results.
    process2 = Process(target=push_to_web,args=(queue_camera_to_web,queue_masked_img,queue_param))
    process2.daemon=True
    process2.start()
    # The measurement loop runs in the main process and never returns.
    predict_new(queue_camera_to_net,queue_masked_img)
|