
Face tracking with recognition and age/gender detection
Code Review user
Asked on 2022-01-31 16:59:57
2 answers · 1.1K views · 0 following · 8 votes

This program loads a video, detects and tracks the faces in it, and estimates each person's age and gender. The aim is to collect data on the number of unique people in the video and their attributes.

Overall code structure

The video is processed frame by frame.

The main data structure is people_info. This is a list of dictionaries, each holding the attributes of one person: the track IDs belonging to that person, a face embedding, and the age and gender categories.
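For illustration, a single people_info entry might look like this (the placeholder values here are mine; only the field names come from the code below):

```python
import numpy as np

# One entry per unique person; the field names match those used in the script.
people_info = [
    {
        "track_ids": [3, 7],              # DeepSORT track IDs assigned to this person
        "age_category_name": "GenX",      # latest age-category estimate
        "gender": "Man",                  # latest gender estimate
        "face_embedding": np.zeros(128),  # 128-d face_recognition encoding (placeholder)
    },
]

# A person ID is simply the index into the list.
person_id = next(i for i, p in enumerate(people_info) if 7 in p["track_ids"])
```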

On each frame:

  1. Detect all faces using a YOLOv3 object detector.
  2. Determine track IDs using the DeepSORT tracking algorithm. A track ID identifies one face trajectory in the video. If the same person appears in the video more than once, they can have multiple track IDs.
  3. If there is a new track ID, check whether the face in this frame is of high enough quality to be processed by the face recognition and age/gender detectors.
    • High quality: use face recognition to compare against previously seen people. If the person is new, use the detectors to determine their age and gender. Add this track to people_info, either as a new entry or appended to an existing entry.
    • Low quality: keep checking every frame until you get a high-quality image of this face.

  4. Output an annotated video showing:
    • A box around each face, tracking each person
    • A 'Track ID' identifying a single appearance of a person (the sequence of frames in which a particular person is present).
    • An estimated 'Person ID' identifying a unique person. If the same person appears more than once, they should be assigned the same ID.
    • The estimated age and gender. These are recalculated every n frames; they often change between frames because the person looks slightly different.
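The branching in steps 2-3 can be sketched as a small decision function (the helper name and the boolean inputs standing in for the real quality check and recognition match are mine):

```python
def decide_track_action(track_id, known_track_ids, face_is_high_quality, person_is_known):
    """Sketch of the per-frame decision logic described above. The two
    booleans stand in for the real quality check and the face-recognition
    match against previously seen people."""
    if track_id in known_track_ids:
        return "reuse cached person info"
    if not face_is_high_quality:
        return "keep checking for a better frame"
    if person_is_known:
        return "append track to the existing person entry"
    return "create a new person entry and run age/gender detection"
```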

All the models are open source and pre-trained.

Code

The code I am posting was originally copied from theAIGuysCode/yolov4-deepsort and references some other modules that you can find there. I modified it to add the recognition and age/gender detection functionality. The script is invoked from the command line like this:

python object_tracker.py --weights ./checkpoints/yolov3-widerface --model yolov3 --video ./data/video/interview.mp4 --output ./outputs/interview.avi --dont_show --face --age_gender

What I added are the --face and --age_gender command-line options.

Questions

Before doing more development on this project, I would like your opinions on how to refactor my code. I am most interested in the parts of the code that are my own: the functions at the start of the code and the places in the main function where they are called (anywhere inside if FLAGS.age_gender). I would be interested in how best to integrate my additions with what is already there. At the moment I think the structure is confusing, because each face crop is resized at different points, bounding-box formats differ, and nested if statements are everywhere. In what structure should I store the information about unique people and track IDs?

I want to develop the code to improve the accuracy of the age/gender detection and the recognition. I would like to store information from multiple frames of a particular person, so that I can average the estimated age and gender to improve the accuracy of the estimates, and also store multiple face embeddings for the same person from different frames, so that I can implement 'voting' to improve face recognition matching.
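A minimal sketch of what that averaging/voting could look like (the function names are mine; the 0.6 threshold mirrors face_recognition's default tolerance):

```python
from collections import Counter

import numpy as np


def vote_category(history):
    """Majority vote over per-frame age or gender predictions for one person."""
    return Counter(history).most_common(1)[0][0]


def vote_identity(new_embedding, stored_embeddings, tolerance=0.6):
    """Declare a match only if the new embedding is within `tolerance`
    (Euclidean distance, as face_recognition uses) of a majority of the
    embeddings stored for a person."""
    matches = sum(
        np.linalg.norm(e - new_embedding) <= tolerance for e in stored_embeddings
    )
    return matches > len(stored_embeddings) / 2
```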

Any suggestions on this code are welcome.

import os

# comment out below line to enable tensorflow logging outputs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import time
import tensorflow as tf

physical_devices = tf.config.experimental.list_physical_devices("GPU")
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from core.config import cfg
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# deep sort imports
from deep_sort import preprocessing, nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
from enum import Enum
from imutils import paths
import face_recognition
import pickle
from PIL import Image, ImageDraw
import dlib


flags.DEFINE_string("framework", "tf", "(tf, tflite, trt)")
flags.DEFINE_string("weights", "./checkpoints/yolov4-416", "path to weights file")
flags.DEFINE_integer("size", 416, "resize images to")
flags.DEFINE_boolean("tiny", False, "yolo or yolo-tiny")
flags.DEFINE_string("model", "yolov4", "yolov3 or yolov4")
flags.DEFINE_string(
    "video", "./data/video/test.mp4", "path to input video or set to 0 for webcam"
)
flags.DEFINE_string("output", None, "path to output video")
flags.DEFINE_string(
    "output_format", "XVID", "codec used in VideoWriter when saving video to file"
)
flags.DEFINE_float("iou", 0.45, "iou threshold")
flags.DEFINE_float("score", 0.95, "score threshold")
flags.DEFINE_boolean("dont_show", False, "dont show video output")
flags.DEFINE_boolean("info", False, "show detailed info of tracked objects")
flags.DEFINE_boolean("count", False, "count objects being tracked on screen")
flags.DEFINE_boolean("face", False, "using yoloface")
flags.DEFINE_boolean("age_gender", False, "detecting age and gender")

# age and gender models accept a face crop image that is 224x224 pixels.
MIN_FACE_SIZE: int = 224

# the size of the bounding box needs to be expanded to be input to age/gender models
BBOX_SCALING = 1.9


class AgeCategories(Enum):
    Child = slice(None, 13)
    GenZ = slice(13, 23)
    Millennial = slice(23, 30)
    GenX = slice(30, 55)
    Boomer = slice(55, None)


AgeCategoryFromIndex = {i: cat for i, cat in enumerate(AgeCategories)}


def restrict_bbox(bbox, w, h):
    """for tlbr, edits bbox so it is within the bounds of the frame. w and h are frame width and height"""
    if len(bbox) == 4:
        return [
            min(max(bbox[0], 0), w),
            min(max(bbox[1], 0), h),
            min(max(bbox[2], 0), w),
            min(max(bbox[3], 0), h),
        ]
    else:
        logging.info("bbox is not valid")


def expand_bbox_square(bbox, frame_width, frame_height):
    """for tlbr, expands the bbox so it is a square, also scales to cover the full face. From trial and error I found that scaling of 1.9 works best."""
    height = bbox[3] - bbox[1]
    width = bbox[2] - bbox[0]

    # making bbox a square for processing
    square_bbox_width = max(width, height) * BBOX_SCALING
    x_centre = (bbox[0] + bbox[2]) / 2
    y_centre = (bbox[1] + bbox[3]) / 2

    y_1_square = y_centre - (square_bbox_width / 2)
    x_1_square = x_centre - (square_bbox_width / 2)
    y_2_square = y_centre + (square_bbox_width / 2)
    x_2_square = x_centre + (square_bbox_width / 2)

    square_bbox = restrict_bbox(
        [x_1_square, y_1_square, x_2_square, y_2_square], frame_width, frame_height
    )
    centre = [x_centre, y_centre]
    return square_bbox, centre


def get_age_gender(face_crop, age_model, gender_model):
    """detects age and gender of a face crop image. The provided image must be square and > 224 pixels wide."""
    assert face_crop.shape[0] == face_crop.shape[1], "face crop is not square"
    assert (
        face_crop.shape[0] >= MIN_FACE_SIZE
    ), f"Too little - expected 224; got {face_crop.shape[0]}"
    detected_face = cv2.resize(
        face_crop, (MIN_FACE_SIZE, MIN_FACE_SIZE), interpolation=cv2.INTER_LINEAR
    )  # (224, 224, 3) now
    img_blob = cv2.dnn.blobFromImage(
        detected_face
    )  # img_blob shape is (1, 3, 224, 224)

    gender_model.setInput(img_blob)
    gender_class = gender_model.forward()[0]
    gender = "Woman " if np.argmax(gender_class) == 0 else "Man"

    age_model.setInput(img_blob)
    age_dist = age_model.forward()[0]
    slot_ages = [sum(age_dist[cat.value]) for cat in AgeCategories]
    age_category_name = AgeCategoryFromIndex[np.argmax(slot_ages)].name

    return age_category_name, gender


def is_face_image_good(face_crop, track, save_faces=True, use_landmarks=False):
    """checks if the face crop image is high quality so it can be processed by
    recognition/age/gender. Checks face is large enough (so high resolution) and
    square (age/gender models accept 224x224 face images). Checks if the face isn't
    side facing by trying to generate face landmarks. Side facing faces cannot generate landmarks."""
    width = face_crop.shape[0]
    height = face_crop.shape[1]

    if (width > MIN_FACE_SIZE) and (width == height):
        # check whether face is side facing by seeing if landmarks can be generated
        # face_recognition accepts rgb ordering

        if use_landmarks:
            face_landmarks_list = face_recognition.face_landmarks(face_crop)

            if len(face_landmarks_list) != 0:

                if save_faces:
                    # need to convert to bgr ordering to use cv2 imwrite
                    cv2.imwrite(
                        "outputs/face_images/original_" + str(track.track_id) + ".jpg",
                        cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB),
                    )

                    # pillow uses rgb ordering
                    pil_image = Image.fromarray(face_crop)
                    d = ImageDraw.Draw(pil_image)
                    for face_landmarks in face_landmarks_list:

                        # Let's trace out each facial feature in the image with a line!
                        for facial_feature in face_landmarks.keys():
                            d.line(face_landmarks[facial_feature], width=5)
                    pil_image.save(
                        "outputs/face_images/"
                        + "annotated_"
                        + str(track.track_id)
                        + ".jpg"
                    )
                return True
            else:
                print("Face is the right size but landmarks could not be generated.")
                return False
        else:
            cv2.imwrite(
                "outputs/face_images/original_" + str(track.track_id) + ".jpg",
                cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB),
            )
            return True
    else:
        print("Face not the right size: " + str(face_crop.shape))
        return False


def match_new_face(face_crop, people_info, track, age_model, gender_model):
    """Find out whether we have seen this person before by comparing against people_info. We generate an embedding for the new
    face and see if there are any matches with earlier recorded people."""
    # resizing face to be smaller to try and avoid memory error - a face that was approx 700x700 was causing a memory allocation error.
    face_crop = cv2.resize(
        face_crop, (MIN_FACE_SIZE, MIN_FACE_SIZE), interpolation=cv2.INTER_LINEAR
    )
    face_embeddings = face_recognition.face_encodings(face_crop)

    if len(face_embeddings) > 0:
        face_embedding = face_embeddings[0]
        # compare the new face with all other people seen before
        for i, person in enumerate(people_info):
            past_face_embedding = person["face_embedding"]
            isSame = face_recognition.compare_faces(
                [past_face_embedding], face_embedding
            )[0]
            if isSame:
                print(
                    "person in track "
                    + str(track.track_id)
                    + " has been seen in a previous tracklet group: "
                    + str(person["track_ids"])
                    + " and the person ID is: "
                    + str(i)
                )
                person_id = i
                # choose to update age/gender
                age_category_name, gender = get_age_gender(
                    face_crop, age_model, gender_model
                )
                people_info[i]["track_ids"].append(track.track_id)
                people_info[i]["gender"] = gender
                people_info[i]["age_category_name"] = age_category_name
                break
        # if no match found, create a new person id.
        else:
            person_id = len(people_info)
            print(
                "person in track "
                + str(track.track_id)
                + " has never been seen before - assign new person ID: "
                + str(person_id)
            )
            # this is a totally new person, so we want to calc their age and gender

            age_category_name, gender = get_age_gender(
                face_crop, age_model, gender_model
            )
            people_info.append(
                {
                    "track_ids": [track.track_id],
                    "age_category_name": age_category_name,
                    "gender": gender,
                    "face_embedding": face_embedding,
                }
            )
    else:
        person_id = "unknown"
        age_category_name = "unknown"
        gender = "unknown"
        print("no face detected for recognition - discard")
    return age_category_name, gender, person_id


def main(_argv):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0

    # initialize deep sort
    model_filename = "model_data/mars-small128.pb"
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # calculate cosine distance metric
    metric = nn_matching.NearestNeighborDistanceMetric(
        "cosine", max_cosine_distance, nn_budget
    )
    # initialize tracker
    tracker = Tracker(metric)

    # load configuration for object detector
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    video_path = FLAGS.video

    # load tflite model if flag is set
    if FLAGS.framework == "tflite":
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    # otherwise load standard tensorflow saved model
    else:
        saved_model_loaded = tf.saved_model.load(
            FLAGS.weights, tags=[tag_constants.SERVING]
        )
        infer = saved_model_loaded.signatures["serving_default"]
    if FLAGS.age_gender:
        gender_model = cv2.dnn.readNetFromCaffe(
            "model_data/gender.prototxt", "model_data/gender.caffemodel"
        )
        age_model = cv2.dnn.readNetFromCaffe(
            "model_data/age.prototxt", "model_data/dex_chalearn_iccv2015.caffemodel"
        )
        # make list that will hold age/gender predictions and face embedding.
        # each element of the list will represent a unique person (not a tracklet). the element is a dict with fields holding age, gender, embedding and track ids

        people_info = []
    # begin video capture
    try:
        vid = cv2.VideoCapture(int(video_path))
    except:
        vid = cv2.VideoCapture(video_path)
    out = None
    frame_width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # get video ready to save locally if flag is set
    if FLAGS.output:
        # by default VideoCapture returns float instead of int
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
        out = cv2.VideoWriter(FLAGS.output, codec, fps, (frame_width, frame_height))
    frame_num = 0
    # while video is running
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            print("Video has ended or failed, try a different video format!")

            break
        frame_num += 1
        # print('Frame #: ', frame_num)
        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
        image_data = image_data / 255.0
        image_data = image_data[np.newaxis, ...].astype(np.float32)
        start_time = time.time()

        # run detections on tflite if flag is set
        if FLAGS.framework == "tflite":
            interpreter.set_tensor(input_details[0]["index"], image_data)
            interpreter.invoke()
            pred = [
                interpreter.get_tensor(output_details[i]["index"])
                for i in range(len(output_details))
            ]
            # run detections using yolov3 if flag is set
            if FLAGS.model == "yolov3" and FLAGS.tiny == True:
                boxes, pred_conf = filter_boxes(
                    pred[1],
                    pred[0],
                    score_threshold=0.25,
                    input_shape=tf.constant([input_size, input_size]),
                )
            else:
                boxes, pred_conf = filter_boxes(
                    pred[0],
                    pred[1],
                    score_threshold=0.25,
                    input_shape=tf.constant([input_size, input_size]),
                )
        else:
            batch_data = tf.constant(image_data)
            pred_bbox = infer(batch_data)
            for key, value in pred_bbox.items():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]
        (
            boxes,
            scores,
            classes,
            valid_detections,
        ) = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])
            ),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=FLAGS.iou,
            score_threshold=FLAGS.score,
        )

        # convert data to numpy arrays and slice out unused elements
        num_objects = valid_detections.numpy()[0]
        bboxes = boxes.numpy()[0]
        bboxes = bboxes[0 : int(num_objects)]
        scores = scores.numpy()[0]
        scores = scores[0 : int(num_objects)]
        classes = classes.numpy()[0]
        classes = classes[0 : int(num_objects)]

        # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
        original_h, original_w, _ = frame.shape
        bboxes = utils.format_boxes(bboxes, original_h, original_w)

        # store all predictions in one parameter for simplicity when calling functions
        pred_bbox = [bboxes, scores, classes, num_objects]

        # read in all class names from config
        class_names = utils.read_class_names(cfg.YOLO.CLASSES)

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())

        # loop through objects and use class index to get class name, allow only classes in allowed_classes list
        names = []
        deleted_indx = []
        for i in range(num_objects):
            class_indx = int(classes[i])
            class_name = class_names[class_indx]
            if class_name not in allowed_classes:
                deleted_indx.append(i)
            else:
                names.append(class_name)
        names = np.array(names)
        count = len(names)
        if FLAGS.count:
            cv2.putText(
                frame,
                "Objects being tracked: {}".format(count),
                (5, 35),
                cv2.FONT_HERSHEY_COMPLEX_SMALL,
                2,
                (0, 255, 0),
                2,
            )
            print("Objects being tracked: {}".format(count))
        # delete detections that are not in allowed_classes
        bboxes = np.delete(bboxes, deleted_indx, axis=0)
        scores = np.delete(scores, deleted_indx, axis=0)

        # encode yolo detections and feed to tracker
        features = encoder(frame, bboxes)
        detections = [
            Detection(bbox, score, class_name, feature)
            for bbox, score, class_name, feature in zip(bboxes, scores, names, features)
        ]

        # initialize color map
        cmap = plt.get_cmap("tab20b")
        colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

        # run non-maxima supression
        boxs = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(
            boxs, classes, nms_max_overlap, scores
        )
        detections = [detections[i] for i in indices]

        # Call the tracker
        tracker.predict()
        tracker.update(detections)

        # update tracks
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()

            class_name = track.get_class()

            if FLAGS.age_gender:
                # we want the bbox to be larger and square (if not on the edge of the frame) so it is in the right format for age/gender detection.
                bbox, bbox_centre = expand_bbox_square(bbox, frame_width, frame_height)
                # face_crop is in rgb ordering
                face_crop = frame[
                    int(bbox[1]) : int(bbox[3]), int(bbox[0]) : int(bbox[2])
                ]

                if any(track.track_id in person["track_ids"] for person in people_info):
                    person_id = np.argmax(
                        [
                            track.track_id in person["track_ids"]
                            for person in people_info
                        ]
                    )
                    # update the age and gender prediction periodically: every 10 frames, attempt an update

                    if frame_num % 10 == 0:
                        if is_face_image_good(face_crop, track):
                            age_category_name, gender = get_age_gender(
                                face_crop, age_model, gender_model
                            )
                            people_info[person_id][
                                "age_category_name"
                            ] = age_category_name
                            people_info[person_id]["gender"] = gender
                            print("updated age and gender")
                    age_category_name = people_info[person_id]["age_category_name"]
                    gender = people_info[person_id]["gender"]
                    print(
                        "this track ID "
                        + str(track.track_id)
                        + " has been recorded before - person ID: "
                        + str(person_id)
                    )
                else:
                    print(
                        "new tracklet ID "
                        + str(track.track_id)
                        + " - checking whether face image is high quality"
                    )
                    # this is a new tracklet
                    if is_face_image_good(face_crop, track):

                        print("face image is high quality")

                        # need to check if we've seen this person before with face recognition

                        age_category_name, gender, person_id = match_new_face(
                            face_crop, people_info, track, age_model, gender_model
                        )
                    else:
                        # discard
                        person_id = "unknown"
                        age_category_name = "unknown"
                        gender = "unknown"
                        print(
                            "face image is of low quality - discard and keep checking for a high quality face image"
                        )
            # draw bbox on screen
            color = colors[int(track.track_id) % len(colors)]
            color = [i * 255 for i in color]
            cv2.rectangle(
                frame,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                color,
                2,
            )
            cv2.rectangle(
                frame,
                (int(bbox[0]), int(bbox[1] - 30)),
                (
                    int(bbox[0]) + (len(class_name) + len(str(track.track_id))) * 17,
                    int(bbox[1]),
                ),
                color,
                -1,
            )
            cv2.putText(
                frame,
                class_name + " Tracker ID - " + str(track.track_id),
                (int(bbox[0]), int(bbox[1] + 20)),
                0,
                0.75,
                (255, 255, 255),
                2,
            )
            if FLAGS.age_gender:
                cv2.putText(
                    frame,
                    "Person ID - "
                    + str(person_id)
                    + ", "
                    + age_category_name
                    + ", "
                    + gender,
                    (int(bbox[0]), int(bbox[3] - 10)),
                    0,
                    0.75,
                    (255, 255, 255),
                    2,
                )
            # if enable info flag then print details about each track
            if FLAGS.info:
                print(
                    "Tracker ID: {}, Class: {},  BBox Coords (xmin, ymin, xmax, ymax): {}".format(
                        str(track.track_id),
                        class_name,
                        (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])),
                    )
                )
        # calculate frames per second of running detections
        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = np.asarray(frame)
        result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if not FLAGS.dont_show:
            cv2.imshow("Output Video", result)
        # if output flag is set, save video file
        if FLAGS.output:
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cv2.destroyAllWindows()


if __name__ == "__main__":
    try:
        app.run(main)
    except SystemExit:
        pass

2 Answers

Code Review user

Answered on 2022-02-01 01:03:17

I am not too familiar with CV (computer vision), so my comments won't address all of your concerns. Nevertheless:

Control blocks and nesting

What makes the code a little hard to follow is the control blocks: you have quite a few ifs and fors. You really should add blank lines to better separate the control blocks and make the code more readable.

You also have nesting of ifs and fors. As you know, Python is sensitive to indentation, so it is easy to produce a flawed control structure that could go unnoticed for a long time.

For example, in the function match_new_face at line 196 you have:

if len(face_embeddings) > 0:

and the matching else far away at line 244:

else:
    person_id = "unknown"
    age_category_name = "unknown"
    gender = "unknown"
    print("no face detected for recognition - discard")
return age_category_name, gender, person_id

You should encapsulate the logic into dedicated functions, so your block should end up looking more like this:

if len(face_embeddings) > 0:
    compare_faces(params)
else:
    create_new_person(params)

I am not sure about the parameters, but you get the idea. Creating dedicated functions will make the code easier to read and easier to debug, and will facilitate unit testing.
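Concretely, the split could look something like this (all the names and the distance check are illustrative, not taken from the post):

```python
def match_or_create_person(face_embedding, people_info, tolerance=0.6):
    """The top level reads as a single decision; each branch becomes a helper."""
    person_id = find_matching_person(face_embedding, people_info, tolerance)
    if person_id is None:
        person_id = create_new_person(face_embedding, people_info)
    return person_id


def find_matching_person(face_embedding, people_info, tolerance):
    """Return the index of the first person within `tolerance`, else None.
    A plain Euclidean distance stands in for face_recognition.compare_faces."""
    for i, person in enumerate(people_info):
        dist = sum(
            (a - b) ** 2 for a, b in zip(person["face_embedding"], face_embedding)
        ) ** 0.5
        if dist <= tolerance:
            return i
    return None


def create_new_person(face_embedding, people_info):
    """Append a fresh entry; the new person ID is its index in the list."""
    people_info.append({"face_embedding": face_embedding, "track_ids": []})
    return len(people_info) - 1
```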

One function, one purpose

Your function get_age_gender does two quite different things: it determines age, and it determines gender. There should be one function per purpose. Simply splitting the existing code in two would be easy, and it is already a slight improvement: the code becomes easier to maintain.

But doing so could lead to code duplication, which we usually want to avoid. For example, you check whether the picture is square. The answer, then, is to create a function that performs that check, and possibly other validation steps, before you start processing the image. Validation should be separate from processing.

There is also a lot of initialization code that should be moved out of the main block. Things like device setup could have their own function.

The point is to remove the clutter; otherwise maintaining the code will be difficult, and adding more functionality over time will become laborious.

Similarly, your function is_face_image_good is declared as follows:

def is_face_image_good(face_crop, track, save_faces=True, use_landmarks=False):

If we look at the code, we can see that when save_faces is True you go on to do something completely unrelated: after some processing, you write images to files. Move that part into another function. As a result, you will see that is_face_image_good becomes much clearer, and the whole of it can fit on your screen.

As its name suggests, your function should do just one thing: tell you whether the face is good, returning a yes or no answer, and that's it.
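A sketch of that split, keeping the predicate pure and moving the file writing elsewhere (the signatures are mine; the size/squareness test mirrors the one in the post):

```python
import numpy as np

MIN_FACE_SIZE = 224


def is_face_image_good(face_crop):
    """Pure predicate with no side effects: large enough and square."""
    width, height = face_crop.shape[0], face_crop.shape[1]
    return width > MIN_FACE_SIZE and width == height


def save_face_image(face_crop, track_id, out_dir="outputs/face_images"):
    """The file-writing side effect lives in its own function."""
    import os
    import cv2  # assumed available, as in the original script

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"original_{track_id}.jpg")
    cv2.imwrite(path, cv2.cvtColor(face_crop, cv2.COLOR_RGB2BGR))
```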

Improve your data model

I think you are missing a Person class. Since your application deals with people and their features/attributes, such as age (range) and gender, it makes sense to model them with a class. people_info is a poor substitute and is not Pythonic.

You have already started using classes and Enum, which makes sense. You could make better use of them, for example for the age categories, and then I think you could drop AgeCategoryFromIndex. If the actual values are not important, you can use auto() instead of plain integers in your enums. The function match_new_face could then return a Person object instead of the multiple variables age_category_name, gender, person_id.

Tip: sometimes you must have functions that return two values; then using a named tuple is a good idea -- tutorial
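For instance, grouping the three values match_new_face currently returns might look like this (illustrative names only):

```python
from typing import NamedTuple


class FaceAttributes(NamedTuple):
    """Bundles the three values match_new_face currently returns separately."""
    age_category_name: str
    gender: str
    person_id: int


attrs = FaceAttributes(age_category_name="GenX", gender="Woman", person_id=4)

# Fields are readable by name, yet the tuple still unpacks positionally.
age_category_name, gender, person_id = attrs
```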

This is not the best example, and I don't understand all of your code either, but maybe this snippet can serve as a start. Consider using a dataclass for conciseness (requires Python 3.7), though it is not an obligation if you are more comfortable with classic classes. Bonus: this implementation takes care of incrementing the ID for you. I have not changed AgeCategories because I haven't yet figured out how the age classification happens.

import itertools

from enum import Enum, auto
from dataclasses import dataclass, field


class Gender(Enum):
    Man = auto()
    Woman = auto()
    Undefined = auto()


class AgeCategories(Enum):
    Child = slice(None, 13)
    GenZ = slice(13, 23)
    Millennial = slice(23, 30)
    GenX = slice(30, 55)
    Boomer = slice(55, None)


@dataclass
class Person:
    name: str
    # ID shall be incremented automatically
    id: int = field(init=False)
    gender: Gender = Gender.Man
    age_category: AgeCategories = AgeCategories.Millennial

    # itertools.count starts at 0 by default, so start at 1
    id_init = itertools.count(start=1)

    def __post_init__(self):
        self.id = next(self.id_init)


person = Person(name="Sarah", gender=Gender.Woman)
print(person)
# this will output:
# Person(name='Sarah', id=1, gender=<Gender.Woman: 2>, age_category=<AgeCategories.Millennial: slice(23, 30, None)>)

# this will output: 2
print(person.gender.value)
# this will output: Gender.Woman
print(person.gender)
# this will output: Woman
print(person.gender.name)

Of course, the classes can live in a separate Python file that you import in your application. This allows reuse and keeps your code base short and lean.

I don't understand what this is about

# begin video capture
try:
    vid = cv2.VideoCapture(int(video_path))
except:
    vid = cv2.VideoCapture(video_path)

Perhaps you should be more explicit about the specific exception you expect, rather than just catching any exception (a lot of things could go wrong). But it looks to me as if you are trying to handle something that shouldn't happen in the first place; this is a workaround rather than a fix.
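One explicit alternative (the helper name is mine): the webcam case is simply "the argument is a digit string", which can be tested directly instead of letting int() raise:

```python
def parse_video_source(video_path):
    """Turn the --video argument into what cv2.VideoCapture expects:
    an int camera index for digit strings like "0", else the file path."""
    s = str(video_path)
    return int(s) if s.isdigit() else s


# vid = cv2.VideoCapture(parse_video_source(FLAGS.video))
```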

Votes 7

Code Review user

Answered on 2022-01-31 18:39:47

Write helper functions

This is a very interesting project. The main problem with your code is the monolithic main function with no helper functions.

For example:

    frame_size = frame.shape[:2]
    image_data = cv2.resize(frame, (input_size, input_size))
    image_data = image_data / 255.0
    image_data = image_data[np.newaxis, ...].astype(np.float32)

could be a function:

def resize_and_convert_frame(frame):
    # ...

to make the code easier to read and test. This is just one example; in general, whenever you want to do something reasonably self-contained with your data, you should try to abstract it into a function.
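Filled in, such a helper might look like this (a sketch; a crude nearest-neighbour resize via NumPy indexing keeps the snippet self-contained, where the real code would call cv2.resize):

```python
import numpy as np


def preprocess_frame(frame, input_size):
    """Resize to (input_size, input_size), scale to [0, 1], add a batch
    dimension, and cast to float32 -- the four steps from the main loop."""
    h, w = frame.shape[:2]
    rows = np.arange(input_size) * h // input_size  # nearest-neighbour stand-in
    cols = np.arange(input_size) * w // input_size  # for cv2.resize
    resized = frame[rows][:, cols]
    return (resized / 255.0)[np.newaxis, ...].astype(np.float32)
```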

Also, this code block could reasonably be its own function:

    # convert data to numpy arrays and slice out unused elements
    num_objects = valid_detections.numpy()[0]
    bboxes = boxes.numpy()[0]
    bboxes = bboxes[0 : int(num_objects)]
    scores = scores.numpy()[0]
    scores = scores[0 : int(num_objects)]
    classes = classes.numpy()[0]
    classes = classes[0 : int(num_objects)]

    # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
    original_h, original_w, _ = frame.shape
    bboxes = utils.format_boxes(bboxes, original_h, original_w)

    # store all predictions in one parameter for simplicity when calling functions
    pred_bbox = [bboxes, scores, classes, num_objects]

    # read in all class names from config
    class_names = utils.read_class_names(cfg.YOLO.CLASSES)

    # by default allow all classes in .names file
    allowed_classes = list(class_names.values())

and so could this one:

    names = []
    deleted_indx = []
    for i in range(num_objects):
        class_indx = int(classes[i])
        class_name = class_names[class_indx]
        if class_name not in allowed_classes:
            deleted_indx.append(i)
        else:
            names.append(class_name)
    names = np.array(names)
    count = len(names)
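As a helper, the block above might read (the function name is mine):

```python
import numpy as np


def filter_allowed_classes(classes, class_names, allowed_classes):
    """The loop above as a function: returns the kept class names and the
    indices of the detections to delete."""
    names, deleted_indx = [], []
    for i, class_idx in enumerate(classes):
        class_name = class_names[int(class_idx)]
        if class_name in allowed_classes:
            names.append(class_name)
        else:
            deleted_indx.append(i)
    return np.array(names), deleted_indx
```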

Now for some less important issues:

Reduce code duplication in conditional branches

You should avoid too much code duplication, for example:

        if FLAGS.model == "yolov3" and FLAGS.tiny == True:
            boxes, pred_conf = filter_boxes(
                pred[1],
                pred[0],
                score_threshold=0.25,
                input_shape=tf.constant([input_size, input_size]),
            )
        else:
            boxes, pred_conf = filter_boxes(
                pred[0],
                pred[1],
                score_threshold=0.25,
                input_shape=tf.constant([input_size, input_size]),
            )

could become:

        if FLAGS.model == "yolov3" and FLAGS.tiny == True:
            first, last = 1, 0
        else:
            first, last = 0, 1
        boxes, pred_conf = filter_boxes(
                pred[first],
                pred[last],
                score_threshold=0.25,
                input_shape=tf.constant([input_size, input_size]),
            )

This makes it possible to see at a glance what changes because of the condition (the 0 and 1 indices are swapped) and what stays the same (the whole function call).

Better string formatting

You should use f-strings if your Python version supports them:

        if FLAGS.age_gender:
            cv2.putText(
                frame,
                "Person ID - "
                + str(person_id)
                + ", "
                + age_category_name
                + ", "
                + gender,
                (int(bbox[0]), int(bbox[3] - 10)),
                0,
                0.75,
                (255, 255, 255),
                2,
            )

becomes:

        if FLAGS.age_gender:
            cv2.putText(
                frame,
                f"Person ID - {person_id}, {age_category_name}, {gender}",
                (int(bbox[0]), int(bbox[3] - 10)),
                0,
                0.75,
                (255, 255, 255),
                2,
            )

which is clearly more readable than before. Otherwise, .format() is also a decent option.

Votes 5
The content of this page originally appeared on Code Review.
Original link:

https://codereview.stackexchange.com/questions/273583
