当前位置 博文首页 > python菜鸟:【深度学习案例】二十行代码批量检测戴口罩

    python菜鸟:【深度学习案例】二十行代码批量检测戴口罩

    作者:[db:作者] 时间:2021-08-30 22:27

    一、定义待预测数据

    # 待预测图片
    test_img_path = ["./img.png"]
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    img = mpimg.imread(test_img_path[0])
    
    # 展示待预测图片
    plt.figure(figsize=(10,10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    

    返回:
    在这里插入图片描述

    若是待预测图片存放在一个文件中,如左侧文件夹所示的test.txt。每一行是待预测图片的存放路径。
    代码:

    with open('mask.txt', 'r') as f:
        try:
            test_img_path=[]
            for line in f:
                test_img_path.append(line.strip())
        except:
            print('图片加载失败')
    print(test_img_path)
    

    返回:
    在这里插入图片描述

    二、加载预训练模型

    PaddleHub口罩检测提供了两种预训练模型,pyramidbox_lite_mobile_mask和pyramidbox_lite_server_mask。二者均是基于2018年百度发表于计算机视觉顶级会议ECCV 2018的论文PyramidBox而研发的轻量级模型,模型基于主干网络FaceBoxes,对于光照、口罩遮挡、表情变化、尺度变化等常见问题具有很强的鲁棒性。不同点在于,pyramidbox_lite_mobile_mask是针对于移动端优化过的模型,适合部署于移动端或者边缘检测等算力受限的设备上。
    代码:

    import paddlehub as hub
    
    module = hub.Module(name="pyramidbox_lite_mobile_mask")
    # module = hub.Module(name="pyramidbox_lite_server_mask")
    

    三、预测

    PaddleHub对于支持一键预测的module,可以调用module的相应预测API,完成预测功能。

    # 口罩检测预测
    visualization=True #将预测结果保存图片可视化
    output_dir='detection_result' #预测结果图片保存在当前运行路径下detection_result文件夹下
    results = module.face_detection(images=imgs, use_multi_scale=True, shrink=0.6, visualization=True, output_dir='detection_result/test.jpg')
    for result in results:
        print(result)
    
    # 预测结果展示
    import matplotlib.image as im
    import matplotlib.pyplot as plt
    import os
    
    # 需要读取的路径
    path_name = r'./detection_result'
    
    for item in os.listdir(path=path_name):
        img = im.imread(os.path.join(path_name, item))
        plt.imshow(img)
        plt.show()
    

    返回如下:
    在这里插入图片描述
    其中,label有’MASK’和’NO MASK’两种选择:'MASK’表示戴了口罩,'NO MASK表示没有佩戴口罩。‘left’/‘rigth’/‘top’/'bottom’表示口罩在图片当中的位置。'confidence’表示预测为佩戴口罩’MASK’或者不佩戴口罩’NO MASK’的概率大小。同时,作为一项完善的开源工作,除了本地推断以外,PaddleHub还支持将该预训练模型部署到服务器或移动设备中。

    四.完整源码

    需要文件也可以左侧联系我,当然我也是百度随便找的。

    # coding=gbk
    """
    作者:川川
    @时间  : 2021/8/30 0:14
    群:970353786
    """
    # 待预测图片
    # test_img_path = ["./img.png"]
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    # img = mpimg.imread(test_img_path[0])
    
    # 展示待预测图片
    # plt.figure(figsize=(10,10))
    # plt.imshow(img)
    # plt.axis('off')
    # plt.show()
    
    with open('mask.txt', 'r') as f:
        try:
            test_img_path=[]
            for line in f:
                test_img_path.append(line.strip())
        except:
            print('图片加载失败')
    print(test_img_path)
    # import os
    
    import cv2
    
    # imgs =[cv2.imread(image_path) for image_path in test_img_path]
    imgs=[cv2.imread(test_img_path[0])]
    # for i in imgs:
    #加载模块
    import paddlehub as hub
    
    module = hub.Module(name="pyramidbox_lite_mobile_mask")
    # module = hub.Module(name="pyramidbox_lite_server_mask")
    # 口罩检测预测
    visualization=True #将预测结果保存图片可视化
    output_dir='detection_result' #预测结果图片保存在当前运行路径下detection_result文件夹下
    results = module.face_detection(images=imgs, use_multi_scale=True, shrink=0.6, visualization=True, output_dir='detection_result')
    for result in results:
        print(result)
    
    # 预测结果展示
    import matplotlib.image as im
    import matplotlib.pyplot as plt
    import os
    
    # 需要读取的路径
    path_name = r'./detection_result'
    
    for item in os.listdir(path=path_name):
        img = im.imread(os.path.join(path_name, item))
        plt.imshow(img)
        plt.show()
    

    如果你想放在服务器上:
    执行如下命令启动模型:

    hub serving start -m pyramidbox_lite_server_mask -p 8866
    

    代码为:

    # coding: utf8
    import requests
    import json
    import base64
    import os
    
    # 指定要检测的图片并生成列表[("image", img_1), ("image", img_2), ... ]
    file_list = ["test.jpg"]
    files = [("image", (open(item, "rb"))) for item in file_list]
    
    # 指定检测方法为pyramidbox_lite_server_mask并发送post请求
    url = "http://127.0.0.1:8866/predict/image/pyramidbox_lite_server_mask"
    r = requests.post(url=url, files=files, data={"visual_result": "True"})
    
    results = eval(r.json()["results"])
    
    # 保存检测生成的图片到output文件夹,打印模型输出结果
    if not os.path.exists("output"):
        os.mkdir("output")
    for item in results:
        with open(os.path.join("output", item["path"]), "wb") as fp:
            fp.write(base64.b64decode(item["base64"].split(',')[-1]))
            item.pop("base64")
    print(json.dumps(results, indent=4, ensure_ascii=False))
    
    cs