使用新的数据集API(作为TF 1.4版本的一部分发布)来加快整个过程
读取CSV文件的步骤:
1)读取CSV文件名
2) 通过提供CSV文件名创建TextLineDataset
3) 为解码创建Parse函数,并对输入数据执行任何预处理工作
4) 使用前面步骤中创建的数据集创建批处理、重复(epoch编号)和无序处理
5) 创建迭代器,将所需的输入作为批处理(即小批量)
Eg代码:from matplotlib.image import imread
def input_model_function():
csv_filename =['images.txt']
dataset = tf.data.TextLineDataset(csv_filename)
dataset = dataset.map(_parse_function)
dataset = dataset.batch(20)# you can use any number of batching
iterator = dataset.make_one_shot_iterator()
sess = tf.Session()
batch_images, batch_labels = sess.run(iterator.get_next())
return {'x':batch_images}, batch_labels
def _parse_function(line):
image, labels= tf.decode_csv(line,record_defaults=[[""], [0]])
# Decode the raw bytes so it becomes a tensor with type.
image = imread(image)# give full path name of image
return image, labels
最后将批处理数据集输入模型(使用任何预先制作的估计器或自定义估计器API创建)