TensorFlow MNIST入门

经过整合的一段TensorFlow训练数据并测试的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import input_data
import tensorflow as tf
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
//-----------------------------以上代码主要是为了导入mnist数据-----
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
//---------------------------定义各个数据的关系,以及数据的张量模型-----
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
//----------------这里我们用交叉熵算法计算来预定于交叉熵变量的值--
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
//-------------梯度下降算法------最小化交叉熵----
init = tf.initialize_all_variables()
//----------初始化所有变量----------------
sess = tf.Session()
sess.run(init)
//----------创建图----启动图-----
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
//---------------训练数据----将 100条mnist的数据读取之后 放到 batch_xs, batch_ys
//--------------- 传递给参数 x 和 参数 y -----------
//我们在上面预定义了X 张量为一个 784 — ∞ 维度的张量,代表着输入,
我们可以输入无数张图片,每张图片784个像素点
// 我们定义了一个特征值 W 张量,是一个 78410 维度的张量

PS: 可以把张量理解成类似C语言数组(虽然不对),可以存放784X10大小的数据 (至于数据是什么后面再介绍,先理解张量)

其实按我的理解,特征值是784个点,有10组,分别代表0-9这10个数字的特征,每个点与数字0-9之间的关联程度,比如图片有784个点

特征值数组有10个 W0,W1,W2…W9(举例:W0都存放着784个点的权值,每个点的权值代表这个点的灰度值与这个点是否是0有关,数字0 正中心的点为空,这个是数字0的特征,我们求的就是这种特征,我们老师以神经元的角度来解释这个,一个个神经元,输入与权值达到一定强度之后,神经元会产生反应,其实我们这个算出来的就是输入与权值的反应强度吧) 我们用tf.matmul(x,W) + b 计算出来各个权值与图片产生的强度之后,用 tf.nn.softmax函数将强度最高的设置为1,其余全部设置为0,这样的结果我们可以与给出的正确答案进行比较,进而反馈给权值。不断的训练权值,使其尽可能的得到W0…W9的特征(即数字0-9的特征,我们其实可以将我们最后求得的W特征数组输出,做成一张类似于热力成像图的那种,不过这里应该叫权值成像图,我自己写了段代码输出的图像与官网给我们的 红蓝成像图类似,代码放在最后面。)

1
2
3
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

// 这里最后对训练的权值进行测试,用另外一组不同的数据来进行测试。

先来看几张特征权值比较 我们看看官方给我们的特征权值图:

这里只贴出0和3的权值分布图,具体的可以自己打印,打印的代码如下:

1
2
3
4
5
6
7
8
9
10
result = sess.run(W)
for i in range(28):
for j in range(28):
if list(result)[(i*28+j)][3] >0.05:
print (1,end="")
elif list(result)[(i*28+j)][3] > - 0.1:
print(" ",end="")
else:
print(2,end="")
print("")

//上面的3代表的是数字3 可以自行修改。至于0.05 和-0.1 是我测试了几次,过滤掉一些无用数据,可以自行调整。

得到的计算结果和训练数据的真实值对比代码

1
2
3
4
batch_xs, batch_ys = mnist.train.next_batch(1)
result = sess.run(y, feed_dict={x: batch_xs})
print (list(result[0]))
print (list(batch_ys[0]))