当前位置 博文首页 > python菜鸟:【深度学习入门案例】几行代码实现任何动物种类识别

    python菜鸟:【深度学习入门案例】几行代码实现任何动物种类识别

    作者:[db:作者] 时间:2021-08-30 22:27

    一、定义待预测数据

    数据集:
    在这里插入图片描述
    代码:

    # 待预测图片
    test_img_path = ['./img/img.png', './img/img_1.png','./img/img_2.png','./img/img_3.png','./img/img_4.png']
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    # 展示其中大狮子图片
    img1 = mpimg.imread(test_img_path[0])
    
    plt.figure(figsize=(10, 10))
    plt.imshow(img1)
    
    plt.axis('off')
    plt.show()
    

    返回:
    在这里插入图片描述
    若是待预测图片存放在一个文件中,如左侧文件夹所示的test.txt。每一行是待预测图片的存放路径。

    with open('tu.txt', 'r') as f:
        try:
            test_img_path=[]
            for line in f:
                test_img_path.append(line.strip())
        except:
            print('数据加载失败')
    print(test_img_path)
    

    返回:
    在这里插入图片描述

    二、 加载预训练模型

    PaddleHub提供了两种动物识别模型:

    • resnet50_vd_animals: ResNet-vd 其实就是 ResNet-D,是ResNet 原始结构的变种,可用于图像分类和特征提取。该 PaddleHub Module 采用百度自建动物数据集训练得到,支持7978种动物的分类识别。
    • mobilenet_v2_animals: MobileNet V2 是一个轻量化的卷积神经网络,它在 MobileNet 的基础上,做了 Inverted Residuals 和 Linear bottlenecks 这两大改进。该 PaddleHub Module 是在百度自建动物数据集上训练得到的,可用于图像分类和特征提取,当前已支持7978种动物的分类识别。

    代码:

    import paddlehub as hub
    module = hub.Module(name="resnet50_vd_animals")
    # module = hub.Module(name="mobilenet_v2_animals")
    

    三、预测

    import cv2
    np_images =[cv2.imread(image_path) for image_path in test_img_path]
    
    results = module.classification(images=np_images)
    
    for result in results:
        print(result)
    

    返回:
    在这里插入图片描述

    四.完整源码

    代码如下:

    # coding=gbk
    """
    作者:川川
    @时间  : 2021/8/29 23:50
    群:970353786
    """
    # 待预测图片
    # test_img_path = ['./img/img.png', './img/img_1.png','./img/img_2.png','./img/img_3.png','./img/img_4.png']
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    # 展示其中大狮子图片
    # img1 = mpimg.imread(test_img_path[0])
    #
    # plt.figure(figsize=(10, 10))
    # plt.imshow(img1)
    #
    # plt.axis('off')
    # plt.show()
    
    
    with open('tu.txt', 'r') as f:
        try:
            test_img_path=[]
            for line in f:
                test_img_path.append(line.strip())
        except:
            print('数据加载失败')
    # print(test_img_path)
    
    import paddlehub as hub
    module = hub.Module(name="resnet50_vd_animals")
    # module = hub.Module(name="mobilenet_v2_animals")
    
    import cv2
    np_images =[cv2.imread(image_path) for image_path in test_img_path]
    
    results = module.classification(images=np_images)
    
    for result in results:
        print(result)
    

    文件架构:
    在这里插入图片描述

    cs