码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

时间:2018-11-12 20:56:50      阅读:260      评论:0      收藏:0      [点我收藏+]

标签:一点   pat   lob   无法   save   占位符   imp   现在   ==   

好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前,不知道我大学的那些小伙伴们现在都怎么样了,考研的刚刚希望他考上,实习的菜头希望他早日脱离苦海,小瑞哥希望他早日出成果,范爷熊健研究生一定要过的开心啊!天哥也哥早日结婚领证!那些回不去的曾经的快乐的时光,你们都还好吗!

 

最近开始接触Tensorflow,可能是论文里用的是这个框架吧,其实我还是觉得pytorch更方便好用一些,仔细读了最简单的Mnist手写识别程序,觉得大同小异,关键要理解Tensorflow的思想,文末就写一下自己看交叉熵的感悟,絮叨了这么多开始写点代码吧!  1 # -*- coding: utf-8 -*-

  2 """
  3 Created on Sun Nov 11 16:14:38 2018
  4 
  5 @author: Yang
  6 """
  7 
  8 import tensorflow as tf 
  9 from tensorflow.examples.tutorials.mnist import input_data 
 10 
 11 mnist = input_data.read_data_sets("/MNIST_data",one_hot=True) #从input_data中读取数据集,使用one_hot编码
 12 
 13 import pylab #画图模块
 14 
 15 tf.reset_default_graph()#重置一下图 图代表了一个运算过程,包含了许多Variable和op,如果不重置一下图的话,可能会因为某些工具重复调用变量而报错
 16 
 17 x = tf.placeholder(tf.float32,[None,784])#占位符,方便用feed_dict进行注入操作
 18 y = tf.placeholder(tf.float32,[None,10])#占位符,方便用feed_dict进行注入操作
20 21 W = tf.Variable(tf.random_normal([784,10]))#要学习的参数统一用Variable来定义,这样方便进行调整更新 22 b = tf.Variable(tf.zeros([10])) 23 24 25 #construct the model 26 pred = tf.nn.softmax(tf.matmul(x,W) + b) 27 28 cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 29 30 learning_rate = 0.01 31 32 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 33 34 #set parameters about thee model 35 training_epoch = 25 36 batch_size = 100 37 display_step = 1 38 saver = tf.train.Saver() 39 model_path = "log/kerwinsmodel.ckpt" 40 41 #start the session 42 43 with tf.Session() as sess : 44 sess.run(tf.global_variables_initializer()) 45 46 for epoch in range(training_epoch): 47 avg_cost = 0 48 total_batch = int(mnist.train.num_examples/batch_size) 49 print(total_batch) 50 for i in range(total_batch): 51 batch_xs,batch_ys = mnist.train.next_batch(batch_size) 52 53 _,c = sess.run([optimizer,cost],feed_dict={x:batch_xs,y:batch_ys}) 54 55 avg_cost += c/ total_batch 56 if (epoch +1 ) % display_step ==0: 57 print("Epoch:",%04d %(epoch+1),"cost=","{:.9f}".format(avg_cost)) 58 59 print("Finish!") 60 61 correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 62 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 63 print("Accuracy:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels})) 64 65 save_path = saver.save(sess,model_path) 66 print("Model saved in file: %s" % save_path) 67 # 68 69 70 #读取模型程序 71 72 print("Starting 2nd session...") 73 with tf.Session() as sess: 74 sess.run(tf.global_variables_initializer()) 75 saver.restore(sess,model_path) 76 77 #测试model 78 correct_prediction = tf.equal(tf.arg_max(pred,1),tf.argmax(y,1)) 79 #计算准确率 80 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 81 print("Accuracy:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels})) 82 83 output = tf.argmax(pred,1) 84 batch_xs,batch_ys = mnist.train.next_batch(2) 85 outputval,predv = sess.run([output,pred],feed_dict={x:batch_xs}) 86 print(outputval,predv,batch_ys) 87 88 im = batch_xs[0] 89 im = im.reshape(-1,28) 90 pylab.imshow(im) 91 pylab.show() 92 93 im = batch_xs[1] 94 im = im.reshape(-1,28) 95 pylab.imshow(im) 96 pylab.show() 97 98 99 100

 

#占位符,方便用feed_dict进行注入操作

Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

标签:一点   pat   lob   无法   save   占位符   imp   现在   ==   

原文地址:https://www.cnblogs.com/kerwins-AC/p/9948841.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!