码迷,mamicode.com
首页 > 其他好文 > 详细

看一下代码效果

时间:2017-08-22 20:51:57      阅读:124      评论:0      收藏:0      [点我收藏+]

标签:dict   int   nbsp   oss   cpp   global   data   remove   feed   

import os,math
# Lower TensorFlow's C++ logging verbosity. The variable name must be a string
# literal, and os.environ only accepts str values. It has to be set before
# `import tensorflow` for TensorFlow to pick it up.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# XOR truth table: each row of x_data is an input pair and the matching row of
# y_data is its target. The dtype matches the float32 placeholders below.
x_data = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
y_data = np.array([[0], [1], [1], [0]], dtype=np.float32)

# Graph inputs: a variable-size batch of 2-feature rows and 1-column targets.
xs = tf.placeholder(tf.float32, [None, 2])
ys = tf.placeholder(tf.float32, [None, 1])

# Hidden layer: 2 tanh units. Weights are drawn from a normal distribution
# shifted by +1; biases start at 0.1.
Weights_1 = tf.Variable(tf.random_normal([2, 2]) + 1)
biases_1 = tf.Variable(tf.zeros([1, 2]) + 0.1)
Wx_plus_b = tf.add(tf.matmul(xs, Weights_1), biases_1)
l1 = tf.nn.tanh(Wx_plus_b)

# Output layer: a single linear unit.
Weights_2 = tf.Variable(tf.random_normal([2, 1]) + 1)
biases_2 = tf.Variable(tf.zeros([1, 1]) + 0.1)
prediction = tf.add(tf.matmul(l1, Weights_2), biases_2)

# Squared error summed over each row's outputs, then averaged over the batch.
# `axis` replaces the deprecated `reduction_indices` alias.
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Create the session and initialise all variables. The session is reused by
# the training loop and the final report.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# Scatter the four XOR input points (first coordinates, second coordinates).
x_data1, y_data1 = [0, 0, 1, 1], [0, 1, 0, 1]
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(x_data1, y_data1)

# Shared abscissa grid for both hidden-unit boundary lines.
# x1 and x2 refer to the same array object.
x1 = np.arange(-1, 1, 0.1)
x2 = x1

# Interactive mode, so the training loop can redraw the figure as it runs.
plt.ion()
plt.show()
# Full-batch gradient descent. Every 500 steps, print the loss and redraw each
# hidden unit's decision line w1[0][j]*x + w1[1][j]*y + b1[0][j] = 0, solved
# for y: y = -(w1[0][j] / w1[1][j]) * x - b1[0][j] / w1[1][j].
train_feed = {xs: x_data, ys: y_data}
# Lists returned by ax.plot for the currently drawn lines; None before the
# first redraw, so there is nothing to erase on step 0.
lines0 = lines1 = None
for i in range(10000):
    sess.run(train_step, feed_dict=train_feed)
    if i % 500 == 0:
        _loss = sess.run(loss, feed_dict=train_feed)
        print(_loss)
        # Fetch all parameters in one session call.
        # w2/b2 are not used for plotting; they are kept as module-level names.
        w1, w2, b1, b2 = sess.run([Weights_1, Weights_2, biases_1, biases_2])
        y1 = -w1[0][0] / w1[1][0] * x1 - b1[0][0] / w1[1][0]
        y2 = -w1[0][1] / w1[1][1] * x2 - b1[0][1] / w1[1][1]
        # Erase the previous lines through Artist.remove(). Recent matplotlib
        # no longer allows mutating ax.lines, and a bare `except: pass` around
        # that call would hide the failure and let stale lines pile up.
        if lines0 is not None:
            lines0[0].remove()
            lines1[0].remove()
        lines0 = ax.plot(x1, y1)
        lines1 = ax.plot(x2, y2)
        plt.pause(0.1)
# Report the trained network's outputs on the four XOR inputs.
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
print(prediction_value)
# Re-fetch the first-layer parameters. The w1/b1 left over from the training
# loop were last sampled at step 9500, not after the final training step.
w1, b1 = sess.run([Weights_1, biases_1])
print(w1, b1)
# Each hidden unit's decision line, printed as "a .x1+ b .x2+ c =0".
print(w1[0][0], ".x1+", w1[1][0], ".x2+", b1[0][0], "=0")
print(w1[0][1], ".x1+", w1[1][1], ".x2+", b1[0][1], "=0")
# The script makes no further session calls; release its resources.
sess.close()

 

看一下代码效果

标签:dict   int   nbsp   oss   cpp   global   data   remove   feed   

原文地址:http://www.cnblogs.com/0xcafe/p/7413428.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!