Recognizing the MNIST Dataset with an RNN

btikc 2024-09-24 08:26:53


## 2019.11.1
# Build an RNN to classify the MNIST dataset.
# MNIST images are 28*28*1; we unroll the RNN over 28 time steps, feeding one 28-pixel row per step.
# The RNN's final memory (hidden state) feeds a fully connected layer that classifies the image.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
tf.set_random_seed(1) # set random seed

# Load the data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print(mnist.test.labels.shape)
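# A quick shape sanity check (a sketch; assumes the standard TF1 MNIST reader):
# each image arrives as a flat 784-vector and reshapes into 28 rows of 28 pixels,
# which is exactly the (time_step, input) layout the RNN will consume.
sample = mnist.train.images[:2]            # shape (2, 784)
print(sample.reshape([-1, 28, 28]).shape)  # prints (2, 28, 28)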

# Learning rate
learning_rate = 0.001
# Upper bound on the total number of training examples seen (not epochs)
training_items = 100000
# Batch size
batch_size = 100
# Inputs per time step (one 28-pixel image row)
n_inputs = 28
# Number of time steps (one per image row)
n_step = 28
# Hidden units in the LSTM
n_hidden_units = 128
# Number of output classes (digits 0-9)
n_classes = 10

x = tf.placeholder(tf.float32, [None, n_step, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])
weight = {
    'in': tf.Variable(tf.random.normal([n_inputs, n_hidden_units])),
    'out': tf.Variable(tf.random.normal([n_hidden_units, n_classes]))
}
biases = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes]))
}
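# Parameter shapes at a glance (derived from the constants above):
#   weight['in'] : (28, 128)  projects each 28-pixel row into the hidden space
#   weight['out']: (128, 10)  maps the final hidden state to class logits
#   biases['in'] : (128,)     biases['out']: (10,)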

def Rnn(X, weight, biases):
    # Flatten (batch, n_step, n_inputs) -> (batch * n_step, n_inputs) for the input projection
    X = tf.reshape(X, [-1, n_inputs])

    X_in = tf.matmul(X, weight['in']) + biases['in']
    X_in = tf.reshape(X_in, [-1, n_step, n_hidden_units])

    ## n_hidden_units controls the number of hidden units, i.e. the state a_prev has shape (n_hidden_units,)
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)

    ## batch_size fixes the state shape: a_prev is (batch_size, n_hidden_units). This is the
    ## transpose of the layout in Andrew Ng's course, but the meaning is the same.
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
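    # init_state is an LSTMStateTuple (c, h), each of shape (batch_size, n_hidden_units).
    # Note: hard-coding batch_size ties the graph to batches of exactly 100. A more
    # flexible sketch (an assumption, not the author's code) derives it at run time:
    #   init_state = lstm_cell.zero_state(tf.shape(X_in)[0], dtype=tf.float32)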

    ## n_step in X_in controls the unrolled depth of the LSTM; each cell receives (batch_size, n_hidden_units).
    ## outputs records the output of every cell; final_state records the state after the last step.
    ## The time_major argument of tf.nn.dynamic_rnn depends on the inputs layout:
    ##   inputs shaped (batches, steps, inputs) ==> time_major=False
    ##   inputs shaped (steps, batches, inputs) ==> time_major=True
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)
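    # With time_major=False, outputs has shape (batch_size, n_step, n_hidden_units):
    # one hidden vector per time step. final_state keeps only the last step's state.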

    # final_state is an LSTMStateTuple (c, h); index 1 is the hidden state h
    results = tf.matmul(final_state[1], weight['out']) + biases['out']
    return results
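# Note: an equivalent readout (a sketch) takes the last time step of outputs,
# since for this cell final_state.h == outputs[:, -1, :]:
#   results = tf.matmul(outputs[:, -1, :], weight['out']) + biases['out']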

pred = Rnn(x, weight, biases)

## softmax and cross-entropy combined in one op
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train = tf.train.AdamOptimizer(learning_rate).minimize(cost)
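# Note: softmax_cross_entropy_with_logits applies the softmax internally, so pred
# must stay raw logits (no softmax inside Rnn). For class probabilities at
# inference time, apply tf.nn.softmax(pred) explicitly.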

# tf.argmax(pred, 1) gives the predicted class index for each example
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
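# Worked example (hypothetical numbers): if a logit row argmax-es to class 7 and
# the one-hot label also argmax-es to 7, correct_pred is True -> cast to 1.0,
# otherwise 0.0; the mean over the batch is the batch accuracy.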

re_loss = []
re_acc_train = []
re_acc_test = []
step = 0

t = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    while step * batch_size < training_items:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)

        # Reshape flat 784-pixel images to (batch, 28 steps, 28 inputs)
        batch_xs = batch_xs.reshape([batch_size, n_step, n_inputs])
        feed = {x: batch_xs, y: batch_ys}  # named 'feed' to avoid shadowing the built-in dict
        sess.run(train, feed_dict=feed)
        if step % 5 == 0:
            batch_txs, batch_tys = mnist.test.next_batch(100)
            batch_txs = batch_txs.reshape([100, n_step, n_inputs])
            re_loss.append(sess.run(cost, feed_dict=feed))
            re_acc_train.append(sess.run(accuracy, feed_dict=feed))
            re_acc_test.append(sess.run(accuracy, feed_dict={x: batch_txs, y: batch_tys}))
            t += 1
        step += 1

plt.plot(np.linspace(0, step, t), re_loss, c='blue', label='Loss')
plt.plot(np.linspace(0, step, t), re_acc_train, c='red', label='Accuracy_Train_Set')
plt.plot(np.linspace(0, step, t), re_acc_test, c='green', label='Accuracy_Test_Set')
plt.legend(loc='best')
plt.show()
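Because init_state hard-codes batch_size, the graph only accepts batches of 100, so the plotted test accuracy is estimated from single 100-image batches. A fuller estimate (a sketch under that same constraint; it must run inside the with tf.Session() block above, before the session closes) averages over the whole test set:

# Average accuracy over disjoint 100-image test batches (run inside the session).
accs = []
for _ in range(mnist.test.num_examples // batch_size):
    tx, ty = mnist.test.next_batch(batch_size)
    tx = tx.reshape([batch_size, n_step, n_inputs])
    accs.append(sess.run(accuracy, feed_dict={x: tx, y: ty}))
print('mean test accuracy:', np.mean(accs))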
