当前位置 博文首页 > cumtchw:HRNet训练自己的分类数据
目录
1.搭建环境
1.1创建conda环境
1.2下载HRNet工程
1.3安装依赖
2.准备数据集
3.修改配置文件
4.增加保存pth模型的代码
5.增加代码保存分类名称和索引
6.开始训练
为了防止和服务器上的环境冲突,这里利用conda搭建环境。
conda create -n HRNet_chw python=3.7
conda activate HRNet_chw
git clone https://github.com/HRNet/HRNet-Image-Classification
首先修改工程里面的requirement.txt。其中增加torchvision==0.8.2和torch==1.7.1。另外由于工程里面原有的requirement.txt文件里面某些依赖的版本找不到,所以把一些版本删掉,修改后的requirement.txt如下:
torchvision==0.8.2
torch==1.7.1
EasyDict==1.7
opencv-python
shapely
Cython
scipy
pandas
pyyaml
json_tricks
scikit-image
yacs>=0.1.5
tensorboardX>=1.6
在HRNet的目录中创建目录??imagenet/images/train? 和????imagenet/images/val,然后在imagenet/images/train里面放上相应的图片,其中每个种类一个单独的文件夹。
├─imagenet
├─ images
│? ├─ train
│? │ ?├─ cat
│? │ ?├─ dog
│? ├─ val
│? │ ?├─ cat
│? │ ?├─ dog
可以修改下配置文件中的batchsize。
原始的train.py只保存了模型的参数,并没有保存完整的模型文件,因此在train.py的最下面增加如下代码用于保存模型文件。
torch.save(model.module.state_dict(), final_model_state_file)#原有代码
final_pth_file = os.path.join(final_output_dir, 'HRNet.pth')#增加的代码
print("final_pth_file:", final_pth_file)#增加的代码
torch.save(model.module, final_pth_file)#增加的代码
在train.py中增加如下代码
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
)
#以下为增加的代码,上面几行是原有的代码
#print(train_dataset.classes) #根据分的文件夹的名字来确定的类别
with open("class.txt","w") as f1:
for classname in train_dataset.classes:
f1.write(classname + "\n")
#print(train_dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
with open("classToIndex.txt", "w") as f2:
for key, value in train_dataset.class_to_idx.items():
f2.write(str(key) + " " + str(value) + '\n')
#print(train_dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
得到的class.txt如下
baby
dog
people
得到的classToIndex.txt如下
baby 0
dog 1
people 2
python tools/train.py --cfg experiments/cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
参考文献:??
pytorch学习笔记七:torchvision.datasets.ImageFolder使用详解??https://blog.csdn.net/qq_39507748/article/details/105394808
pytorch中保存的模型文件.pth深入解析? ??https://blog.csdn.net/qq_27825451/article/details/100773473
cs