网站首页 > 技术文章 正文
2. RNN网络实现识别mnist数据集
## 2019.11.1 # 通过建立RNN网络来对mnist进行识别 # mnist数据是28*28*1的图片, 我们建立28个RNN单元, 每个单元输入数据为28 # 我们使用RNN网络最后输出的记忆A, 建立全连接, 识别分类图片 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data tf.set_random_seed(1) # set random seed # 导入数据 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) print(mnist.test.labels.shape) # 学习率 learning_rate = 0.001 # epoch training_items = 100000 # 批次 batch_size = 100 # 每个RNN-Cell 的收入数据 n_inputs = 28 # time_step n_step = 28 # LSTM中的隐藏神经元 n_hidden_units = 128 n_classes = 10 x = tf.placeholder(tf.float32, [None, n_step, n_inputs]) y = tf.placeholder(tf.float32, [None, n_classes]) weight = { 'in': tf.Variable(tf.random.normal([n_inputs, n_hidden_units])), 'out': tf.Variable(tf.random.normal([n_hidden_units, n_classes])) } bais = { 'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units])), 'out': tf.Variable(tf.constant(0.1, shape=[n_classes])) } def Rnn(X, weight, bais): X = tf.reshape(X, [-1, n_inputs]) X_in = tf.matmul(X, weight['in']) + bais['in'] X_in = tf.reshape(X_in, [-1, n_step, n_hidden_units]) ## n_hidden_units 控制神经元的个数, 也就是状态a_perv(n_hidden_units,) lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True) ## batch_size 控制了状态a_perv(batch_size, n_hidden_units) 于吴恩达讲的结构相反, 意思一样 init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) ##X_in 中的n_step控制了LSTM的深度, 对于每一个cell我们输入[batch_size, n_hidden_units] ## outputs记录了每一个cell的输出, final_state记录最后的输出对于状态 ## tf.nn.dynamic_rnn中的time_major参数会针对不同inputs 格式有不同的值. ## 如果inputs为(batches, steps, inputs) == > time_major = False; ## 如果inputs为(steps, batches, inputs) == > time_major = True; outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False) results = tf.matmul(final_state[1], weight['out']) + bais['out'] return results pred = Rnn(x, weight, bais) ## softmax and cross_entroy cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) train = tf.train.AdamOptimizer(learning_rate).minimize(cost) # tf.argmax(pred, axis=1) correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) re_loss=[] re_acc_train=[] re_acc_test=[] step = 0 t = 0 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) while step * batch_size < training_items: batch_xs, batch_ys = mnist.train.next_batch(batch_size) batch_xs = batch_xs.reshape([batch_size, n_step, n_inputs]) dict = {x: batch_xs, y: batch_ys} sess.run(train, feed_dict=dict) if step % 5 == 0: batch_txs, batch_tys = mnist.test.next_batch(100) batch_txs = batch_txs.reshape([100, n_step, n_inputs]) re_loss.append(sess.run(cost, feed_dict=dict)) re_acc_train.append(sess.run(accuracy, feed_dict=dict)) re_acc_test.append(sess.run(accuracy, feed_dict={x: batch_txs, y: batch_tys})) t += 1 step += 1 plt.subplot(1,1,1) plt.plot(np.linspace(0, step, t), re_loss, c='blue', label='Loss') plt.plot(np.linspace(0, step, t), re_acc_train, c='red', label='Accuracy_Train_Sets') plt.plot(np.linspace(0, step, t), re_acc_test, c='green', label='Accuracy_Test_Set') plt.legend(loc='best') plt.show()
- 上一篇: 要为学习神经网络奠定基础,你需要认真读读R深度学习
- 下一篇: AI攻城狮,你需要那个数据集的种子吗?
猜你喜欢
- 2024-09-24 行业篇:自动驾驶场景下的数据标注类别分享
- 2024-09-24 AI预标注,人工智能基础数据服务行业的新引擎丨曼孚科技
- 2024-09-24 基于Movielens-1M数据集和相似性矩阵实现的电影推荐算法(附源码)
- 2024-09-24 AAAI 2022 | GAN的结构有“指纹”吗?从伪造图像溯源生成网络结构
- 2024-09-24 人工智能时代,数据标注产业将迎来黄金时期?丨曼孚科技
- 2024-09-24 R数据分析:如何用R做多重插补,实例操练
- 2024-09-24 AI攻城狮,你需要那个数据集的种子吗?
- 2024-09-24 要为学习神经网络奠定基础,你需要认真读读R深度学习
- 2024-09-24 CL0P组织利用Seed传输窃取的敏感数据 (上)
- 2024-09-24 详解SEED数据服务平台(5):批改与批注
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)