码迷,mamicode.com
首页 > 编程语言 > 详细

tensorflow使用softmax regression算法实现手写识别

时间:2017-10-27 01:19:28      阅读:295      评论:0      收藏:0      [点我收藏+]

标签:dict   像素   tde   feed   from   输入数据   平均值   优化器   min   

最近在学习黄文坚的TensorFlow书籍,希望对学习做一个总结。

softmax regression算法原理:当我们对一张图片进行预测时,会计算每一个数字的可能性,如3的概率是3%,5的概率是6%,1的概率是80%,则返回1.

TensorFlow版本:0.8.0

# 导入手写识别数据,TensorFlow提供了手写识别库
from
tensorflow.examples.tutorials.mnist import input_data
# 读取手写识别数据 mnist
= input_data.read_data_sets("MNIST_data/", one_hot=True)
# 训练集数据的维度是(55000,784),训练集标签的维度是(55000,10)
# 测试集数据的维度是(10000,784),测试集标签的维度是(10000,10)
# 验证集数据的维度是(5000,784),验证集标签的维度是(5000,10)
# 为什么训练数据的维度是784?因为tensorflow提供的数据集的图片像素是28*28=784
# 为什么标签的维度是10,标签做了处理,每个预期结果变成了只包含0和1的10维数据。
# 例如标签5就表示为[0,0,0,0,0,1,0,0,0,0],这种方法叫one-hot编码
print(mnist.train.images.shape,mnist.train.labels.shape) print(mnist.test.images.shape,mnist.test.labels.shape) print(mnist.validation.images.shape,mnist.validation.labels.shape)

# 导入TensorFlow库
import tensorflow as tf
# 将session注册为默认的session,运算都在session里跑。placeholder为输入数据的地方
# placeholder的第一个参数表示数据类型,第2个参数表示数据的维度,None表示任意长度的数据 sess
=tf.InteractiveSession() x = tf.placeholder(tf.float32,[None,784])
# Variable用于存储参数,它是持久化的,可以长期存在,每次迭代都会更新 # 数据的维度是784,类别的维度经过one-hot编码后变成了10维,所以W的参数为[784,10]
# b为[10]维,W和b全部初始化为0,简单模型的初始值不重要
W
= tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) # softmax函数用于定义softmax regression算法
# matmul用于向量乘法
y
=tf.nn.softmax(tf.matmul(x,W)+b)
# 求损失函数cross-entropy,先定义一个placeholder,输入的真实label
# cross_entropy定义了损失函数的计算方法,通过reduce_sum求熵的和,reduce_mean求每个batch的熵的平均值 y_
=tf.placeholder(tf.float32,[None,10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
# 定义一个优化器,GradientDescentOptimizer为优化器,学习率为0.5,优化目标设定为cross_entropy train_step
= tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 全局参数初始化并执行run tf.initialize_all_variables().run()
# 每次取100个样本,并feed给placeholder,执行1000次,train_step对数据进行训练
for i in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) train_step.run({x:batch_xs,y_:batch_ys})
# 求出概率最大的数字,判断是否与实际标签相符合,y是预测数据,y_是实际数据 correct_prediction
= tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# 求计算精度 accuracy
=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})

总的来说,TensorFlow感觉还是比较简单的,也许这只是个最简单的模型吧。
涉及的概念也只有session,variable,placeholder,GradientDescentOptimizer。
梯度下降等复杂的方法都进行了封装,用python不到30行的代码就实现了手写识别,虽然识别正确率只有92%左右。


 

tensorflow使用softmax regression算法实现手写识别

标签:dict   像素   tde   feed   from   输入数据   平均值   优化器   min   

原文地址:http://www.cnblogs.com/eagle-1024/p/7739711.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!