"""Detect a cat in a video with MobileNet-SSD, then follow it with a KCF tracker.

The DNN forward pass is expensive, so detection runs only until a cat is
found; after that the much cheaper KCF tracker follows the bounding box
frame-to-frame. When the tracker loses the target, detection resumes.
"""
import cv2

# MobileNet-SSD (Pascal VOC) class index for "cat".
CAT_CLASS_ID = 8
# Minimum detection confidence required to start tracking.
CONFIDENCE_THRESHOLD = 0.3

PROTO_TXT = "model/MobileNetSSD_deploy.prototxt"
CAFFE_MODEL = "model/MobileNetSSD_deploy.caffemodel"
DNN = cv2.dnn.readNetFromCaffe(PROTO_TXT, CAFFE_MODEL)

cap = cv2.VideoCapture("data/4.mp4")
tracker = cv2.TrackerKCF_create()
tracking = False

while True:
    ret, frame = cap.read()
    # Guard against end-of-video / read failure: frame is None when ret is
    # False, and frame.shape would raise AttributeError.
    if not ret:
        break

    if not tracking:
        height, width = frame.shape[:2]

        # MobileNet-SSD expects a 300x300 input scaled to roughly [-1, 1].
        blob = cv2.dnn.blobFromImage(
            frame,
            scalefactor=1 / 127.5,
            size=(300, 300),
            mean=(127.5, 127.5, 127.5),
            swapRB=True,
            crop=False,
        )
        DNN.setInput(blob)
        detections = DNN.forward()

        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence <= CONFIDENCE_THRESHOLD:
                continue
            if int(detections[0, 0, i, 1]) != CAT_CLASS_ID:
                continue

            # Detection coordinates are normalized to [0, 1]; scale to pixels.
            x_top_left = int(detections[0, 0, i, 3] * width)
            y_top_left = int(detections[0, 0, i, 4] * height)
            x_bottom_right = int(detections[0, 0, i, 5] * width)
            y_bottom_right = int(detections[0, 0, i, 6] * height)

            # Trackers take (x, y, w, h), not corner pairs.
            bbox = (
                x_top_left,
                y_top_left,
                x_bottom_right - x_top_left,
                y_bottom_right - y_top_left,
            )

            tracker.init(frame, bbox)
            tracking = True
            print("Tracking started")
            # Track only the first accepted detection; without this break the
            # tracker would be re-initialized for every further detection in
            # the same frame.
            break
    else:
        success, bbox = tracker.update(frame)
        if success:
            x, y, w, h = (int(v) for v in bbox)
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        else:
            # Target lost: discard this tracker (KCF cannot be re-inited
            # reliably after failure) and fall back to detection.
            tracking = False
            tracker = cv2.TrackerKCF_create()
            print("Tracking stopped")

    cv2.namedWindow("frame", cv2.WINDOW_NORMAL)
    cv2.imshow("frame", frame)
    # imshow only renders during waitKey; 'q' quits the loop.
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

cap.release()
cv2.destroyAllWindows()