码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow暑期实践——基于单个神经元的手写数字识别(全部代码)

时间:2020-07-10 10:00:02      阅读:62      评论:0      收藏:0      [点我收藏+]

标签:learning   log   coding   writer   reduce   utf-8   写入   visible   des   

# coding: utf-8

import tensorflow as tf 
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
print(tf.__version__)
print(tf.test.is_gpu_available())


from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

tf.reset_default_graph() #清除default graph和不断增加的节点

x = tf.placeholder(tf.float32, [None, 784]) # mnist 中每张图片共有28*28=784个像素点
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 一共10个数字=> 10 个类别


norm = tf.random_normal([100]) #生成100个随机数
with tf.Session() as sess:
    norm_data=norm.eval()
print(norm_data[:10])                  #打印前10个随机数


import matplotlib.pyplot as plt
plt.hist(norm_data)
plt.show()

W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) 


forward=tf.matmul(x, W) + b # 前向输出
tf.summary.histogram(forward,forward)#将前向输出值以直方图显示

pred = tf.nn.softmax(forward) # Softmax分类

train_epochs = 30
batch_size = 100
total_batch= int(mnist.train.num_examples/batch_size)
display_step = 1
learning_rate=0.01

loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) # 交叉熵
tf.summary.scalar(loss, loss_function)#将损失以标量显示

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #梯度下降

# 检查预测类别tf.argmax(pred, 1)与实际类别tf.argmax(y, 1)的匹配情况
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# 准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将布尔值转化为浮点数,并计算平均值

tf.summary.scalar(accuracy, accuracy)#将准确率以标量显示


sess = tf.Session() #声明会话
init = tf.global_variables_initializer() # 变量初始化
sess.run(init)

merged_summary_op = tf.summary.merge_all()#合并所有summary
writer = tf.summary.FileWriter(log/mnist_single_neuron, sess.graph) #创建写入符

# 开始训练
for epoch in range(train_epochs ):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)# 读取批次数据
        sess.run(optimizer,feed_dict={x: xs,y: ys}) # 执行批次训练
        
        #生成summary
        summary_str = sess.run(merged_summary_op,feed_dict={x: xs,y: ys})
        writer.add_summary(summary_str, epoch)#将summary 写入文件
        #total_batch个批次训练完成后,使用验证数据计算误差与准确率   
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
    # 打印训练过程中的详细信息
    if (epoch+1) % display_step == 0:
        print("Train Epoch:", %02d % (epoch+1), "Loss=", "{:.9f}".format(loss)," Accuracy=","{:.4f}".format(acc))

print("Train Finished!")  


print("Test Accuracy:", sess.run(accuracy,
                           feed_dict={x: mnist.test.images, y: mnist.test.labels}))


prediction_result=sess.run(tf.argmax(pred,1), # 由于pred预测结果是one-hot编码格式,所以需要转换为0~9数字
                           feed_dict={x: mnist.test.images })

prediction_result[0:10] #查看预测结果中的前10项


import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(10, 12)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)
        
        ax.imshow(np.reshape(images[idx],(28, 28)), 
                  cmap=binary)
            
        title= "label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx]) 
            
        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()


plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,0)


# 预测结果
plot_images_labels_prediction(mnist.test.images,
                             mnist.test.labels,
                             prediction_result,10,25)

 

Tensorflow暑期实践——基于单个神经元的手写数字识别(全部代码)

标签:learning   log   coding   writer   reduce   utf-8   写入   visible   des   

原文地址:https://www.cnblogs.com/caiyishuai/p/13277407.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!