码迷,mamicode.com
首页 > 其他好文 > 详细

机器学习-MNIST数据集-神经网络

时间:2019-03-23 22:40:44      阅读:266      评论:0      收藏:0      [点我收藏+]

标签:tps   csdn   base   建模   val   http   data   方式   实现   

 1 #设置随机种子
 2 seed = 7 
 3 numpy.random.seed(seed)
 4 
 5 #加载数据
 6 (X_train,y_train),(X_test,y_test) = mnist.load_data() 
 7 #print(X_train.shape[0])
 8 
 9 #数据集是3维的向量(instance length,width,height).对于多层感知机,模型的输入是二维的向量,因此这里需要将数据集reshape,即将28*28的向量转成784长度的数组。可以用numpy的reshape函数轻松实现这个过程。
10 num_pixels = X_train.shape[1] * X_train.shape[2] 
11 X_train = X_train.reshape(X_train.shape[0],num_pixels).astype(float32)
12 X_test = X_test.reshape(X_test.shape[0],num_pixels).astype(float32)
13 
14 #给定的像素的灰度值在0-255,为了使模型的训练效果更好,通常将数值归一化映射到0-1
15 X_train = X_train / 255
16 X_test = X_test / 255
17 # one hot encoding
18 y_train = np_utils.to_categorical(y_train)
19 y_test = np_utils.to_categorical(y_test)
20 num_classes = y_test.shape[1]
21 
22 # 搭建神经网络模型了,创建一个函数,建立含有一个隐层的神经网络
23 def baseline_model():
24     model = Sequential() # 建立一个Sequential模型,然后一层一层加入神经元
25     # 第一步是确定输入层的数目正确:在创建模型时用input_dim参数确定。例如,有784个个输入变量,就设成num_pixels。
26     #全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数。这里的初始化方法是0到0.05的连续型均匀分布(uniform),Keras的默认方法也是这个。也可以用高斯分布进行初始化(normal)。
27     # 具体定义参考:https://cnbeining.github.io/deep-learning-with-python-cn/3-multi-layer-perceptrons/ch7-develop-your-first-neural-network-with-keras.html
28     model.add(Dense(num_pixels,input_dim=num_pixels,kernel_initializer=normal,activation=relu))
29     model.add(Dense(num_classes,kernel_initializer=normal,activation=softmax))
30     model.compile(loss=categorical_crossentropy,optimizer=adam,metrics=[accuracy])
31     return model
32 
33 model = baseline_model()
34 #model.fit() 函数每个参数的意义参考:https://blog.csdn.net/a1111h/article/details/82148497
35 model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=10,batch_size=200,verbose=2) 
36 # 1、模型概括打印
37 model.summary()
38 
39 scores = model.evaluate(X_test,y_test,verbose=0) #model.evaluate 返回计算误差和准确率
40 print(scores)
41 print("Base Error:%.2f%%"%(100-scores[1]*100))

 

机器学习-MNIST数据集-神经网络

标签:tps   csdn   base   建模   val   http   data   方式   实现   

原文地址:https://www.cnblogs.com/david2018098/p/10585856.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!