1、概述
前面介绍了RNN，这一节就用TensorFlow的RNN（具体使用LSTM单元）来训练MNIST数据集，看看准确率如何。
2、代码实现
2.1、导入数据集
# encoding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download MNIST into the local "mnist_data/" directory and load it,
# encoding each label as a one-hot vector.
mnist = input_data.read_data_sets(train_dir="mnist_data/", one_hot=True)
2.2、定义基本参数
# Every MNIST image is 28x28 pixels. The RNN reads an image row by row:
# n_steps time steps, each carrying one row of n_input pixel values.
n_input = 28
n_steps = 28
# Number of units in the LSTM hidden state.
n_hidden = 128
# Ten output classes: the digits 0-9.
n_classes = 10
2.3、定义占位符
# Input images, fed already reshaped to (batch, n_steps, n_input).
x = tf.placeholder(tf.float32, shape=[None, n_steps, n_input])
# One-hot target labels, shape (batch, n_classes).
y = tf.placeholder(tf.float32, shape=[None, n_classes])
2.4、数据转换
# static_rnn consumes a Python list with one tensor per time step.
# Splitting x along axis 1 (the time axis) yields n_steps tensors of shape
# (batch, n_input): row i of every image in the batch forms step i.
x1 = tf.unstack(x, num=n_steps, axis=1)
2.5、定义RNN
# Basic LSTM cell with n_hidden units; forget_bias=1.0 is added to the
# forget gate to reduce forgetting early in training.
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0)
# Unroll the cell statically over the n_steps inputs in x1.
outputs, states = tf.contrib.rnn.static_rnn(cell=lstm_cell, inputs=x1,
                                            dtype=tf.float32)
2.6、全连接层
# Dense layer on the output of the last time step. activation_fn=None keeps
# raw logits; the softmax is applied inside the loss.
pred = tf.contrib.layers.fully_connected(inputs=outputs[-1],
                                         num_outputs=n_classes,
                                         activation_fn=None)
2.7、定义学习率等
# Step size for the Adam optimizer.
learning_rate = 0.001
# Total number of training SAMPLES to feed (not the number of steps):
# the training loop keeps going while step * batch_size < training_iters.
training_iters = 20000
# Number of images per minibatch.
batch_size = 10
2.8、定义损失和优化器
# Softmax cross-entropy between the logits and the one-hot labels,
# averaged over the batch.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred)
)
# One Adam update step that minimizes the mean loss.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
2.9、测试准确率的方法
# A prediction is correct when the highest logit sits at the same index
# as the 1 in the one-hot label.
correct_pred = tf.equal(tf.argmax(pred, axis=1), tf.argmax(y, axis=1))
# Fraction of correct predictions over the fed batch.
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
2.10、开始训练
# Report minibatch loss/accuracy every display_step steps
# (every 1000 samples at batch_size=10).
display_step = 100
# Start a session. The test-accuracy code in section 2.11 also uses `sess`,
# so it must run inside this `with` block as well.
with tf.Session() as sess:
    # Initialize all global variables (model weights and Adam's state).
    sess.run(tf.global_variables_initializer())
    step = 1
    # training_iters counts samples, not steps: with the values from 2.7
    # this loop runs for steps 1..1999.
    while step * batch_size < training_iters:
        # Next minibatch of flattened images, reshaped to (batch, n_steps, n_input).
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        # Was `step % 1000`, which fired only once (Iter 10000) because the
        # loop ends at step 1999; 100 matches the log shown in 2.12.
        if step % display_step == 0:
            # Fetch loss and accuracy on the current minibatch in one run
            # instead of two separate forward passes.
            loss, acc = sess.run([cost, accuracy],
                                 feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print(" Finished!")
2.11、计算准确率
# Accuracy on the first test_len test images. NOTE: `sess` must still be
# open, so in a real script these lines belong inside the
# `with tf.Session() as sess:` block from section 2.10.
test_len = 128
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print("Testing Accuracy:",
      sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
2.12、训练结果
Iter 14000, Minibatch Loss= 0.348060, Training Accuracy= 0.90000
Iter 15000, Minibatch Loss= 0.086765, Training Accuracy= 1.00000
Iter 16000, Minibatch Loss= 0.158794, Training Accuracy= 0.90000
Iter 17000, Minibatch Loss= 0.087399, Training Accuracy= 1.00000
Iter 18000, Minibatch Loss= 0.046167, Training Accuracy= 1.00000
Iter 19000, Minibatch Loss= 0.026566, Training Accuracy= 1.00000
 Finished!
('Testing Accuracy:', 0.953125)
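准确率在95%左右，不及CNN。需要注意的是，这个数字只是在前128张测试图片上测得的，而且训练总共只用了20000个样本（batch_size=10，约2000步，不到一个epoch），所以只是一个粗略的估计；增加训练量或在完整测试集上评估会得到更可靠的结果。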
2.13、完整代码
# encoding:utf-8
"""Train a single-layer LSTM on MNIST, reading each 28x28 image row by row.

Requires TensorFlow 1.x (uses tf.placeholder, tf.Session and tf.contrib).
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download MNIST into "mnist_data/" and load it with one-hot labels.
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)

# Every MNIST image is 28x28: feed it as n_steps time steps of n_input values
# (one pixel row per step).
n_input = 28
n_steps = 28
# Number of units in the LSTM hidden state.
n_hidden = 128
# Digits 0-9.
n_classes = 10

# Images fed as (batch, n_steps, n_input); labels as one-hot (batch, n_classes).
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# static_rnn wants a list of n_steps tensors, each of shape (batch, n_input).
x1 = tf.unstack(x, n_steps, 1)

# Basic LSTM cell, unrolled statically over the n_steps inputs.
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x1, dtype=tf.float32)

# Dense layer on the last time step's output; raw logits (softmax is in the loss).
pred = tf.contrib.layers.fully_connected(outputs[-1], n_classes, activation_fn=None)

learning_rate = 0.001
# Total number of training SAMPLES (the loop runs while step * batch_size < this).
training_iters = 20000
batch_size = 10
# Report minibatch loss/accuracy every display_step steps.
display_step = 100

# Mean softmax cross-entropy, minimized with Adam.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Fraction of samples whose arg-max logit matches the arg-max of the label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 1
    while step * batch_size < training_iters:
        # Flattened minibatch reshaped to (batch, n_steps, n_input).
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # One forward pass fetches both minibatch loss and accuracy.
            loss, acc = sess.run([cost, accuracy],
                                 feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print(" Finished!")

    # Test accuracy on the first test_len test images (must run while the
    # session is still open).
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
总结
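这个例子还可以用GRU网络或者动态RNN网络来训练，TensorFlow都提供了相应的API，稍微改一下即可。例如：
- 把`BasicLSTMCell`换成`tf.contrib.rnn.GRUCell`；
- 把`static_rnn`换成`tf.nn.dynamic_rnn`。后者可以直接接收形状为(batch, n_steps, n_input)的输入，不需要`unstack`，此时最后一步的输出要用`outputs[:, -1, :]`取得。

但是学习时间太少，还是抓紧时间往后学吧。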