当前位置 博文首页 > cumtchw:HRNet训练自己的分类数据

    cumtchw:HRNet训练自己的分类数据

    作者:[db:作者] 时间:2021-07-28 15:12

    目录

    1.搭建环境

    1.1创建conda环境

    1.2下载HRNet工程

    1.3安装依赖

    2.准备数据集

    3.修改配置文件

    4.增加保存pth模型的代码

    5.增加代码保存分类名称和索引

    6.开始训练


    1.搭建环境

    1.1创建conda环境

    为了防止和服务器上的环境冲突,这里利用conda搭建环境。

    conda create -n HRNet_chw python=3.7
    conda activate HRNet_chw

    1.2下载HRNet工程

    git clone https://github.com/HRNet/HRNet-Image-Classification

    1.3安装依赖

    首先修改工程里面的requirement.txt。其中增加torchvision==0.8.2和torch==1.7.1。另外由于工程里面原有的requirement.txt文件里面某些依赖的版本找不到,所以把一些版本删掉,修改后的requirement.txt如下:

    torchvision==0.8.2
    torch==1.7.1
    EasyDict==1.7
    opencv-python
    shapely
    Cython
    scipy
    pandas
    pyyaml
    json_tricks
    scikit-image
    yacs>=0.1.5
    tensorboardX>=1.6

    2.准备数据集

    在HRNet的目录中创建目录??imagenet/images/train? 和????imagenet/images/val,然后在imagenet/images/train里面放上相应的图片,其中每个种类一个单独的文件夹。

    ├─imagenet
      ├─ images
      │? ├─ train
      │? │ ?├─ cat
      │? │ ?├─ dog
      │? ├─ val
      │? │ ?├─ cat
      │? │ ?├─ dog

    3.修改配置文件

    可以修改下配置文件中的batchsize。

    4.增加保存pth模型的代码

    原始的train.py只保存了模型的参数,并没有保存完整的模型文件,因此在train.py的最下面增加如下代码用于保存模型文件。

        torch.save(model.module.state_dict(), final_model_state_file)#原有代码
    
        final_pth_file = os.path.join(final_output_dir,  'HRNet.pth')#增加的代码
        print("final_pth_file:", final_pth_file)#增加的代码
        torch.save(model.module, final_pth_file)#增加的代码

    5.增加代码保存分类名称和索引

    在train.py中增加如下代码

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        )
    
    
        #以下为增加的代码,上面几行是原有的代码
        #print(train_dataset.classes)  #根据分的文件夹的名字来确定的类别
        with open("class.txt","w") as f1:
            for classname in train_dataset.classes:
                f1.write(classname + "\n")
    
        #print(train_dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
        with open("classToIndex.txt", "w") as f2:
            for key, value in train_dataset.class_to_idx.items():
                f2.write(str(key) + " " + str(value) + '\n')
    
        #print(train_dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
    
    
    

    得到的class.txt如下

    baby
    dog
    people

    得到的classToIndex.txt如下

    baby 0
    dog 1
    people 2

    6.开始训练

    python tools/train.py --cfg experiments/cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

    参考文献:??

    pytorch学习笔记七:torchvision.datasets.ImageFolder使用详解??https://blog.csdn.net/qq_39507748/article/details/105394808

    pytorch中保存的模型文件.pth深入解析? ??https://blog.csdn.net/qq_27825451/article/details/100773473

    cs