Quadratic Curve Fitting with a Simple ANN

Posted: 2017-08-23 10:25:07

Code:

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide TensorFlow INFO and WARNING logs
import numpy as np
import matplotlib.pyplot as plt
# The script uses the TensorFlow 1.x graph API (placeholders, Session).
# Importing it through compat.v1 keeps it runnable on TensorFlow 1.15+ and 2.x.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Define a fully connected layer: outputs = activation(inputs @ W + b)
def add_layer(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size, out_size]), name="W")
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name="b")
    Wx_plus_b = tf.add(tf.matmul(inputs, Weights), biases)
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

# Noisy training data: y = x^2 - 0.5 plus Gaussian noise (std 0.05), 300 points on [-1, 1]
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

xs = tf.placeholder(tf.float32, [None, 1], name="x_input")
ys = tf.placeholder(tf.float32, [None, 1], name="y_input")

# Network: input layer (1 unit), hidden layer (4 tanh units), output layer (1 unit).
# The mean squared error is minimized by gradient descent with learning rate 0.1.
l1 = add_layer(xs, 1, 4, activation_function=tf.nn.tanh)
prediction = add_layer(l1, 4, 1, activation_function=None)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=1))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Initialize the variables
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Plot the training data, then redraw the fitted curve every 50 steps
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
lines = None
for i in range(10000):
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        _loss = sess.run(loss, feed_dict={xs: x_data, ys: y_data})
        print(_loss)
        if _loss < 0.005:  # stop early once the fit is good enough
            break
        if lines is not None:
            lines[0].remove()  # erase the previously drawn curve
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        lines = ax.plot(x_data, prediction_value, "r-", lw=5)
        plt.pause(0.1)

# Keep the final figure open after training ends
plt.ioff()
plt.show()
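
The next listing makes explicit what GradientDescentOptimizer does here. It is a minimal NumPy sketch of the same 1-4-1 tanh network trained on the same kind of noisy data. It uses full-batch gradient descent with learning rate 0.1, the same mean-squared-error loss, and the same 0.005 stopping threshold. Several details are assumptions made for illustration and do not come from the original post: the helper names `fit_quadratic` and `predict`, the fixed seed, the hand-written backpropagation, and the loss check on every step instead of every 50 steps. Its numbers will not match a TensorFlow run exactly.

```python
import numpy as np


def predict(params, x):
    """Forward pass (hypothetical helper): tanh hidden layer, linear output layer."""
    h = np.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def fit_quadratic(steps=10000, lr=0.1, hidden=4, tol=0.005, seed=0):
    """Hypothetical helper: full-batch gradient descent on a 1-hidden-1 tanh network."""
    rng = np.random.default_rng(seed)
    # Same data recipe as the TensorFlow script: y = x^2 - 0.5 + N(0, 0.05^2) noise
    x = np.linspace(-1, 1, 300)[:, np.newaxis]
    y = np.square(x) - 0.5 + rng.normal(0, 0.05, x.shape)

    # Initialization mirrors add_layer: weights ~ N(0, 1), biases = 0.1
    params = {
        "W1": rng.standard_normal((1, hidden)),
        "b1": np.full((1, hidden), 0.1),
        "W2": rng.standard_normal((hidden, 1)),
        "b2": np.full((1, 1), 0.1),
    }
    n = x.shape[0]
    losses = []
    for _ in range(steps):
        # Forward pass, keeping the hidden activations for the backward pass
        h = np.tanh(x @ params["W1"] + params["b1"])
        err = h @ params["W2"] + params["b2"] - y
        loss = float(np.mean(np.sum(err ** 2, axis=1)))  # same loss as the TF code
        losses.append(loss)
        if loss < tol:  # same early-stop threshold as the original script
            break
        # Backward pass: d(loss)/d(prediction) = 2 * err / n
        d_out = 2.0 * err / n
        # Chain rule through tanh, using tanh'(z) = 1 - tanh(z)^2
        d_hidden = (d_out @ params["W2"].T) * (1.0 - h ** 2)
        grads = {
            "W2": h.T @ d_out,
            "b2": d_out.sum(axis=0, keepdims=True),
            "W1": x.T @ d_hidden,
            "b1": d_hidden.sum(axis=0, keepdims=True),
        }
        # Plain gradient-descent step: theta <- theta - lr * grad
        for name in params:
            params[name] -= lr * grads[name]
    return params, losses, x, y


if __name__ == "__main__":
    params, losses, x, y = fit_quadratic()
    print(f"steps run: {len(losses)}, first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

All gradients are computed before any parameter changes, so the update is one simultaneous gradient-descent step, which is what a single `sess.run(train_step, ...)` performs on the graph.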

Original post: http://www.cnblogs.com/0xcafe/p/7414436.html