Reference: https://github.com/tensorflow/models/tree/master/slim
Image Classification with TensorFlow-Slim
Preparation
Install TensorFlow
See https://www.tensorflow.org/install/
For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:
    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
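After installing, it can help to confirm that the package imports before moving on. The following is a minimal sketch; the helper name installed_tf_version is hypothetical, and it only reports the version string without checking GPU support.

```python
def installed_tf_version():
    """Return the installed TensorFlow version string, or None if it is missing.

    Hypothetical helper for a quick post-install check.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf.__version__


version = installed_tf_version()
if version is None:
    print('TensorFlow is not importable in this Python environment')
else:
    print('TensorFlow version: {}'.format(version))
```

With the wheel above, the reported version should be 1.2.0.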
Download the TF-Slim image models library
    cd $WORKSPACE
    git clone https://github.com/tensorflow/models/
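Prepare the data
Many public datasets are available; here we use the Flowers dataset provided by the official TF-Slim project as an example.
The official project provides code for downloading and converting the data. To understand that code and be able to apply it to our own data, we adapt it here rather than running it unchanged.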
    mkdir -p $WORKSPACE/data
    cd $WORKSPACE/data
    wget http://download.tensorflow.org/example_images/flower_photos.tgz
    tar zxf flower_photos.tgz
The dataset folder is structured as follows:
    flower_photos
    ├── daisy
    │   ├── 100080576_f52e8ee070_n.jpg
    │   └── ...
    ├── dandelion
    ├── LICENSE.txt
    ├── roses
    ├── sunflowers
    └── tulips
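In practice, our own datasets do not necessarily store images in one folder per class, so we generate list.txt to record the mapping between each image path and its label. Each line has the form "<class>/<image> <label>".
Python code: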
    import os

    # Map each class (sub-directory) name to an integer label.
    class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    data_dir = 'flower_photos/'
    output_path = 'list.txt'

    # Write one "<class>/<image> <label>" line per image.
    with open(output_path, 'w') as fd:
        for class_name in class_names_to_ids.keys():
            images_list = os.listdir(os.path.join(data_dir, class_name))
            for image_name in images_list:
                fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))
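The class-to-ID mapping can also be derived from the folder names whenever the images do sit in one folder per class, as they do in flower_photos. The following is a minimal sketch under that assumption. The helper name build_class_map is hypothetical and not part of TF-Slim; it assigns IDs in alphabetical order, which for the five flower classes gives daisy 0 through tulips 4.

```python
import os
import tempfile


def build_class_map(data_dir):
    """Map each sub-directory name of data_dir to an integer ID.

    Hypothetical helper (not part of TF-Slim). Plain files such as
    LICENSE.txt are skipped; IDs follow alphabetical order.
    """
    class_names = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name)))
    return {name: idx for idx, name in enumerate(class_names)}


# Demo on a throwaway directory that mimics the flower_photos layout.
demo_dir = tempfile.mkdtemp()
for name in ['tulips', 'daisy', 'roses', 'sunflowers', 'dandelion']:
    os.mkdir(os.path.join(demo_dir, name))
open(os.path.join(demo_dir, 'LICENSE.txt'), 'w').close()

class_map = build_class_map(demo_dir)
print(class_map)
```

On the real dataset, the call would be build_class_map('flower_photos/').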
To make the labels easier to look up later, you can also define labels.txt, with one class name per line in label-ID order:
    daisy
    dandelion
    roses
    sunflowers
    tulips
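labels.txt can be written from the same class-to-ID mapping instead of by hand, so that the position of each line matches its ID. A minimal sketch, assuming that one-name-per-line convention; the helper names write_labels and read_labels are hypothetical.

```python
import os
import tempfile

class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2,
                      'sunflowers': 3, 'tulips': 4}


def write_labels(class_map, path):
    """Write one class name per line, ordered by class ID."""
    with open(path, 'w') as fd:
        for name in sorted(class_map, key=class_map.get):
            fd.write(name + '\n')


def read_labels(path):
    """Return a list where index i holds the name of class ID i."""
    with open(path) as fd:
        return [line.strip() for line in fd if line.strip()]


# Demo: write to a temporary location, then look up a label by its ID.
labels_path = os.path.join(tempfile.mkdtemp(), 'labels.txt')
write_labels(class_names_to_ids, labels_path)
labels = read_labels(labels_path)
print(labels[2])  # prints "roses"
```

Reading the file back as a list turns a predicted class ID into a name with a plain index lookup.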
Randomly split the list into a training set and a validation set:
Python code:
    import random

    _NUM_VALIDATION = 350
    _RANDOM_SEED = 0
    list_path = 'list.txt'
    train_list_path = 'list_train.txt'
    val_list_path = 'list_val.txt'

    with open(list_path) as fd:
        lines = fd.readlines()

    # Shuffle with a fixed seed so the split is reproducible.
    random.seed(_RANDOM_SEED)
    random.shuffle(lines)

    # The first _NUM_VALIDATION shuffled lines form the validation set; the rest form the training set.
    with open(train_list_path, 'w') as fd:
        fd.writelines(lines[_NUM_VALIDATION:])
    with open(val_list_path, 'w') as fd:
        fd.writelines(lines[:_NUM_VALIDATION])
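The split logic can be wrapped in a function and checked on made-up data before it touches the real lists. This is a minimal sketch; the helper name split_lines and the demo file names are hypothetical. It uses a local random.Random instance rather than the global random.seed, which is a design choice of this sketch and not what the script above does.

```python
import random


def split_lines(lines, num_validation, seed=0):
    """Shuffle a copy of lines and return (train_lines, val_lines).

    Hypothetical helper: the first num_validation shuffled lines go to
    validation, the remainder to training.
    """
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    return shuffled[num_validation:], shuffled[:num_validation]


# Made-up list entries in the same "<class>/<image> <label>" format.
demo_lines = ['daisy/img_{}.jpg 0\n'.format(i) for i in range(10)]
train, val = split_lines(demo_lines, num_validation=3)

# The two subsets should not overlap and should cover every line.
assert not set(train) & set(val)
assert sorted(train + val) == sorted(demo_lines)
print('{} training lines, {} validation lines'.format(len(train), len(val)))
```

Because the seed is fixed, calling split_lines again with the same input returns the same split.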
Generate the TFRecord data:
Python code:
    import sys
    # Make the TF-Slim `datasets` package importable (path is relative to $WORKSPACE/data).
    sys.path.insert(0, '../models/slim/')
    from datasets import dataset_utils
    import math
    import os
    import tensorflow as tf


    def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
        # Each line of the list file is "<relative image path> <label>".
        with open(list_path) as fd:
            lines = [line.split() for line in fd]
        num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
        with tf.Graph().as_default():
            # Decode each JPEG once to read its height and width.
            decode_jpeg_data = tf.placeholder(dtype=tf.string)
            decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
            with tf.Session('') as sess:
                for shard_id in range(_NUM_SHARDS):
                    output_path = os.path.join(output_dir,
                        'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
                    tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
                            i + 1, len(lines), shard_id))
                        sys.stdout.flush()
                        image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
                        image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
                        height, width = image.shape[0], image.shape[1]
                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, int(lines[i][1]))
                        tfrecord_writer.write(example.SerializeToString())
                    # Close the writer once per shard, after all of its images are written.
                    tfrecord_writer.close()
        sys.stdout.write('\n')
        sys.stdout.flush()


    os.system('mkdir -p train')
    convert_dataset('list_train.txt', 'flower_photos', 'train/')
    os.system('mkdir -p val')
    convert_dataset('list_val.txt', 'flower_photos', 'val/')
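The shard boundaries in convert_dataset come from a ceiling division followed by a clamp to the list length. The arithmetic can be checked on its own without TensorFlow. This is a minimal sketch; the helper name shard_ranges is hypothetical and simply repeats the start_ndx and end_ndx computation from the script.

```python
import math


def shard_ranges(num_items, num_shards):
    """Return the (start, end) index pair of every shard, end exclusive.

    Hypothetical helper mirroring the start_ndx / end_ndx arithmetic
    used in convert_dataset.
    """
    num_per_shard = int(math.ceil(num_items / float(num_shards)))
    ranges = []
    for shard_id in range(num_shards):
        start = shard_id * num_per_shard
        end = min((shard_id + 1) * num_per_shard, num_items)
        ranges.append((start, end))
    return ranges


# 350 validation images over 5 shards: 70 images per shard.
print(shard_ranges(350, 5))
# → [(0, 70), (70, 140), (140, 210), (210, 280), (280, 350)]

# 12 items over 5 shards: 3 per shard, so the last shard is empty.
print(shard_ranges(12, 5))
# → [(0, 3), (3, 6), (6, 9), (9, 12), (12, 12)]
```

The second case shows that a shard file can end up with no records when the item count is small relative to the shard count, so a smaller _NUM_SHARDS may be preferable for small datasets.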