The Road to TensorFlow Mastery (14): Training an RNN on the MNIST Dataset


1. Overview

The previous posts introduced RNNs. In this one we use TensorFlow's RNN API to train on the MNIST dataset and see what accuracy it can reach.

2. Implementation

2.1 Importing the dataset

# encoding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download and load the MNIST dataset
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)

2.2 Defining the basic parameters

# Each MNIST image is 28x28; treat each image as a sequence of
# 28 time steps, with 28 values per step
n_input = 28
n_steps = 28
# Number of hidden units in the LSTM cell
n_hidden = 128
# Ten classes: digits 0-9
n_classes = 10

2.3 Defining the placeholders

# Placeholders for the input images and the one-hot labels
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

2.4 Converting the data

# Reshape the raw [batch, 28, 28] input into a list of 28 tensors,
# each of shape [batch, 28], one per time step, to feed into the RNN
x1 = tf.unstack(x, n_steps, 1)
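To make the shape transformation concrete, here is a minimal, illustrative check (not part of the original tutorial; the shapes match the parameters above):

import tensorflow as tf

x = tf.placeholder("float", [None, 28, 28])  # [batch, n_steps, n_input]
x1 = tf.unstack(x, 28, 1)                    # split along axis 1 (the time axis)
print(len(x1))       # 28: one tensor per time step
print(x1[0].shape)   # (?, 28): [batch, n_input]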

2.5 Defining the RNN

# The cell class; here we use an LSTM, and BasicLSTMCell is the basic LSTM variant
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Build the RNN from the cell; here we use the static RNN.
# outputs is a list of n_steps output tensors, one per time step.
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)

2.6 The fully connected layer

# Fully connected layer on the output of the last time step
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

2.7 Defining the learning rate and other hyperparameters

# Learning rate, number of training iterations, and batch size
learning_rate = 0.001
training_iters = 20000
batch_size = 10

2.8 Defining the loss and the optimizer

# Loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

2.9 Measuring accuracy

# Accuracy: the fraction of predictions that match the labels
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

2.10 Training

# Launch the session
with tf.Session() as sess:
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    step = 1
    # Training loop: runs while step * batch_size < 20000, i.e. about 2000 steps
    while step * batch_size < training_iters:
        # Fetch a batch of MNIST data
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % 100 == 0:
            # Accuracy on the current batch
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Loss on the current batch
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Finished!")

2.11 Computing the test accuracy

    # Still inside the session: accuracy on the first 128 test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

2.12 Training results

Iter 14000, Minibatch Loss= 0.348060, Training Accuracy= 0.90000
Iter 15000, Minibatch Loss= 0.086765, Training Accuracy= 1.00000
Iter 16000, Minibatch Loss= 0.158794, Training Accuracy= 0.90000
Iter 17000, Minibatch Loss= 0.087399, Training Accuracy= 1.00000
Iter 18000, Minibatch Loss= 0.046167, Training Accuracy= 1.00000
Iter 19000, Minibatch Loss= 0.026566, Training Accuracy= 1.00000
Finished!
('Testing Accuracy:', 0.953125)

The accuracy is around 95%, which falls short of what a CNN achieves on MNIST.

2.13 Complete code

# encoding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download and load the MNIST dataset
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)

# Each MNIST image is 28x28; treat each image as a sequence of
# 28 time steps, with 28 values per step
n_input = 28
n_steps = 28
# Number of hidden units in the LSTM cell
n_hidden = 128
# Ten classes: digits 0-9
n_classes = 10

# Placeholders for the input images and the one-hot labels
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Reshape the raw [batch, 28, 28] input into a list of 28 tensors,
# each of shape [batch, 28], one per time step, to feed into the RNN
x1 = tf.unstack(x, n_steps, 1)

# The cell class; here we use an LSTM, and BasicLSTMCell is the basic LSTM variant
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Build the RNN from the cell; here we use the static RNN
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)

# Fully connected layer on the output of the last time step
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

# Learning rate, number of training iterations, and batch size
learning_rate = 0.001
training_iters = 20000
batch_size = 10

# Loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Accuracy: the fraction of predictions that match the labels
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Launch the session
with tf.Session() as sess:
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    step = 1
    # Training loop
    while step * batch_size < training_iters:
        # Fetch a batch of MNIST data
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % 100 == 0:
            # Accuracy on the current batch
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Loss on the current batch
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Finished!")

    # Accuracy on the first 128 test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

3. Summary

This example could also be trained with a GRU network or with a dynamic RNN; TensorFlow provides APIs for both, and only small changes are needed, as the sketch below shows. But study time is limited, so it is best to press on to the next topic.
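For reference, here is a minimal sketch of those two variants, assuming the same n_input, n_steps, n_hidden, n_classes, and placeholder x as above; it swaps in tf.contrib.rnn.GRUCell and tf.nn.dynamic_rnn and was not benchmarked in this article:

import tensorflow as tf

n_input = 28
n_steps = 28
n_hidden = 128
n_classes = 10
x = tf.placeholder("float", [None, n_steps, n_input])

# GRU variant: swap the cell type; the rest of the graph is unchanged
gru_cell = tf.contrib.rnn.GRUCell(n_hidden)

# Dynamic RNN variant: takes the [batch, n_steps, n_input] tensor directly,
# so the tf.unstack step is no longer needed
outputs, states = tf.nn.dynamic_rnn(gru_cell, x, dtype=tf.float32)

# dynamic_rnn returns outputs of shape [batch, n_steps, n_hidden];
# take the last time step before the fully connected layer
pred = tf.contrib.layers.fully_connected(outputs[:, -1, :], n_classes,
                                         activation_fn=None)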
