将两张图片全部存入同一个 TFRecord 文件，然后读取并显示第一张。
不多写了，直接贴代码。
# Write two labelled images into a single TFRecord file, then read the first
# record back and display it.
#
# NOTE: this uses the TensorFlow 1.x API (tf.python_io, queue runners,
# tf.Session). Under TensorFlow 2 these are only reachable via tf.compat.v1
# in graph mode.
import os

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

IMAGE_PATH = 'test/'
tfrecord_file = os.path.join(IMAGE_PATH, 'test.tfrecord')

# Open TFRecord writers keyed by output path. Writers are created lazily on
# first use (not at import time) and kept open, so successive calls for the
# same path append records to one file instead of truncating it each time.
_writers = {}


def _get_writer(path):
    """Return the open TFRecordWriter for *path*, creating it on first use."""
    writer = _writers.get(path)
    if writer is None:
        writer = tf.python_io.TFRecordWriter(path)
        _writers[path] = writer
    return writer


def _close_writers():
    """Flush and close every writer opened by ``_get_writer``."""
    while _writers:
        _, writer = _writers.popitem()
        writer.close()


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    ``int()`` turns NumPy integer scalars (such as the ``np.int32`` shape
    entries from ``get_image_binary``) into plain Python ints for protobuf.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def _bytes_feature(value):
    """Wrap a single bytes string in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def get_image_binary(filename):
    """Load an image file and return ``(shape, raw_bytes)``.

    You can read in the image using TensorFlow too, but it's a drag since you
    have to create graphs. It's much easier using Pillow and NumPy.

    Returns:
        shape: ``np.int32`` array holding the pixel array's shape, ``(h, w)``
            for single-band images or ``(h, w, c)`` for multi-band ones.
        raw_bytes: the ``uint8`` pixel data as returned by ``ndarray.tobytes``.
    """
    # Read the pixels inside the ``with`` so the file handle is released
    # as soon as the data has been converted.
    with Image.open(filename) as img:
        pixels = np.asarray(img, np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()


def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one labelled image as a ``tf.train.Example`` and append it.

    To write more samples, call this in a loop with the same *tfrecord_file*.

    Args:
        label: integer class label.
        shape: image shape ``(h, w)`` or ``(h, w, c)``. A 2-D shape is stored
            with ``c = 1`` so single-band (grayscale) images are supported.
        binary_image: raw ``uint8`` pixel bytes matching *shape*.
        tfrecord_file: output path. Records for the same path go through one
            shared writer; call ``_close_writers()`` when done to flush them.
    """
    channels = shape[2] if len(shape) > 2 else 1
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(channels),
        'image': _bytes_feature(binary_image),
    }))
    _get_writer(tfrecord_file).write(example.SerializeToString())


def write_tfrecord(label, image_file, tfrecord_file):
    """Load *image_file* and append it with *label* to *tfrecord_file*."""
    shape, binary_image = get_image_binary(image_file)
    write_to_tfrecord(label, shape, binary_image, tfrecord_file)


def _read_first_example(path):
    """Build a TF1 queue-based reader for *path* and return its first record.

    Returns:
        ``[image, label]`` as NumPy values: ``image`` is reshaped to
        ``(h, w, c)`` with dtype ``uint8``; ``label`` has shape ``(1,)`` and
        dtype ``int32``.
    """
    # A single input file; shuffle=False makes the read order explicit.
    filename_queue = tf.train.string_input_producer([path], shuffle=False)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'h': tf.FixedLenFeature([], tf.int64),
            'w': tf.FixedLenFeature([], tf.int64),
            'c': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })
    h = tf.cast(features['h'], tf.int32)
    w = tf.cast(features['w'], tf.int32)
    c = tf.cast(features['c'], tf.int32)
    image = tf.reshape(tf.decode_raw(features['image'], tf.uint8), [h, w, c])
    label = tf.reshape(tf.cast(features['label'], tf.int32), [1])
    # To read mini-batches instead, pass the tensors to tf.train.batch; it
    # needs a fully defined image shape (unless dynamic_pad=True), e.g. by
    # resizing with tf.image.resize_images first.

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            return sess.run([image, label])
        finally:
            # Stop the queue-runner threads even if sess.run raises.
            coord.request_stop()
            coord.join(threads)


def main():
    """Write two sample images (labels 1 and 2), then show the first one."""
    labels = [1, 2]
    image_files = [os.path.join(IMAGE_PATH, 'a.jpg'),
                   os.path.join(IMAGE_PATH, 'b.jpg')]
    try:
        for label, image_file in zip(labels, image_files):
            write_tfrecord(label, image_file, tfrecord_file)
    finally:
        # Flush the records to disk before they are read back below.
        _close_writers()

    image, label = _read_first_example(tfrecord_file)
    print(label)
    plt.figure()
    if image.shape[-1] == 1:
        # Drop the singleton channel axis (some matplotlib versions reject
        # (h, w, 1) arrays) and show it with a gray colormap.
        plt.imshow(image[:, :, 0], cmap='gray')
    else:
        plt.imshow(image)
    plt.show()


if __name__ == '__main__':
    main()