1. 创建 TFRecord 文件
TFRecord 支持写入三种类型的数据:bytes(字符串)、int64 和 float32。这些数据都以列表的形式,分别通过 tf.train.BytesList、tf.train.Int64List、tf.train.FloatList 包装后写入 tf.train.Feature,如下所示:
# Bytes: serialize the (usually multi-dimensional) array to raw bytes and wrap
# the result in a one-element list.
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tobytes()]))
# Int64: tobytes() discards the array's shape (and dtype), so store the shape
# separately in order to reshape the array when reading it back.
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))
# Float: FloatList takes a list of floats, so a scalar label is wrapped in a
# list; a vector label (e.g. a one-hot array) is passed directly as `value`.
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
将上述各个 tf.train.Feature 以 dict 的形式汇总,用它构建 tf.train.Features,再由 tf.train.Features 构建 tf.train.Example,如下:
def get_tfrecords_example(feature, label):
    """Pack one sample into a ``tf.train.Example``.

    Args:
        feature: array exposing ``shape`` and ``tobytes()`` (a NumPy array).
            Its raw bytes are stored under ``'feature'``. The bytes keep
            neither shape nor dtype, so the shape is stored separately under
            ``'shape'``. The reader must already know the dtype.
        label: sequence of floats (e.g. a one-hot vector), stored under
            ``'label'``.

    Returns:
        A ``tf.train.Example`` with the features ``'feature'``, ``'shape'``
        and ``'label'``.
    """
    tfrecords_features = {
        # ndarray.tobytes() is the non-deprecated equivalent of tostring().
        'feature': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[feature.tobytes()])),
        'shape': tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(feature.shape))),
        'label': tf.train.Feature(
            float_list=tf.train.FloatList(value=label)),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=tfrecords_features))
将创建好的 tf.train.Example 序列化后,即可通过 tf.python_io.TFRecordWriter 写入 TFRecord 文件,如下:
# Create a writer for the file 'xxx.tfrecord'. The with-block closes it when
# writing is finished, even if an exception is raised part-way.
with tf.python_io.TFRecordWriter('xxx.tfrecord') as tfrecord_wrt:
    exmp = get_tfrecords_example(feats[inx], labels[inx])  # pack one sample into an Example
    exmp_serial = exmp.SerializeToString()  # serialize the Example to a byte string
    tfrecord_wrt.write(exmp_serial)  # write the record to the tfrecord file
代码汇总:
import random

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

# Load MNIST from MNIST_data/ with one-hot encoded labels.
mnist = read_data_sets("MNIST_data/", one_hot=True)


def get_tfrecords_example(feature, label):
    """Pack one sample into a ``tf.train.Example``.

    Args:
        feature: array exposing ``shape`` and ``tobytes()`` (a NumPy array).
            Its raw bytes are stored under ``'feature'``. The bytes keep
            neither shape nor dtype, so the shape is stored separately under
            ``'shape'``. The reader must already know the dtype.
        label: sequence of floats (here a one-hot vector), stored under
            ``'label'``.

    Returns:
        A ``tf.train.Example`` with the features ``'feature'``, ``'shape'``
        and ``'label'``.
    """
    tfrecords_features = {
        # ndarray.tobytes() is the non-deprecated equivalent of tostring().
        'feature': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[feature.tobytes()])),
        'shape': tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(feature.shape))),
        'label': tf.train.Feature(
            float_list=tf.train.FloatList(value=label)),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=tfrecords_features))


def make_tfrecord(data, outf_nm='mnist-train'):
    """Serialize every sample of ``data`` into the file ``<outf_nm>.tfrecord``.

    Args:
        data: ``(feats, labels)`` pair of equal-length sequences; each record
            is built from one feature and the label at the same position.
        outf_nm: output file name without the ``.tfrecord`` extension.
    """
    feats, labels = data
    outf_nm += '.tfrecord'
    # The with-block closes the writer even if writing fails part-way.
    with tf.python_io.TFRecordWriter(outf_nm) as tfrecord_wrt:
        for feat, label in zip(feats, labels):
            exmp = get_tfrecords_example(feat, label)
            tfrecord_wrt.write(exmp.SerializeToString())


# Shuffle the indices of the MNIST training set, then split them 85% / 15%
# into training and validation sets. list() is required: random.shuffle works
# in place, and a Python 3 range object is immutable.
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))
random.shuffle(inx_lst)
ntrains = int(0.85 * nDatas)

# make training set
train_inx = inx_lst[:ntrains]
data = ([mnist.train.images[i] for i in train_inx],
        [mnist.train.labels[i] for i in train_inx])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
val_inx = inx_lst[ntrains:]
data = ([mnist.train.images[i] for i in val_inx],
        [mnist.train.labels[i] for i in val_inx])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')