
A Simple Linear Regression Model with TensorFlow

Posted: 2018-04-27 22:58:05


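#!/usr/local/bin/python3
## ljj
## Linear regression model: fit y = w*x + b with TensorFlow 1.x

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Training data: 12 (x, y) points
x_ = [11, 14, 22, 29, 32, 40, 44, 55, 59, 60, 69, 77]
y_res = [123, 135, 155, 167, 177, 189, 200, 240, 250, 255, 277, 298]

# Trainable parameters, both initialized to 1
w = tf.Variable(tf.ones([1]), dtype=tf.float32)
b = tf.Variable(tf.ones([1]), dtype=tf.float32)

# Placeholders for the inputs and for the target labels.
# The target needs its own placeholder: feeding it into y_predict
# would overwrite the model's prediction instead of supplying a label.
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])

# Model and mean-squared-error loss, built once before the session starts
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
train = tf.train.AdamOptimizer(0.7).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Full-batch training: every step uses all 12 points.
    # 2000 steps is an assumed value; increase it if the loss is still falling.
    for step in range(2000):
        sess.run(train, feed_dict={x: x_, y: y_res})
    w_, b_ = sess.run([w, b])
    print(w_, b_)

# Plot the data points and the fitted line
plt.plot(x_, y_res, '.')
plt.plot(x_, np.array(x_) * w_ + b_, '-')
plt.show()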
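For reference, the script fits the model below by minimizing the mean squared error over the n = 12 points. The gradients are what `minimize()` obtains through automatic differentiation before Adam rescales them; setting both gradients to zero gives the ordinary least-squares solution, which any well-converged run should approach. This is a standard derivation added for clarity, not something stated in the original post.

$$
\begin{aligned}
\hat{y}_i &= w\,x_i + b \\
L(w, b) &= \frac{1}{n}\sum_{i=1}^{n}\left(y_i - w\,x_i - b\right)^2 \\
\frac{\partial L}{\partial w} &= -\frac{2}{n}\sum_{i=1}^{n} x_i\left(y_i - w\,x_i - b\right) \\
\frac{\partial L}{\partial b} &= -\frac{2}{n}\sum_{i=1}^{n}\left(y_i - w\,x_i - b\right) \\
w^{*} &= \frac{\sum_{i=1}^{n}(x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n}(x_i - \bar{x})^2},
\qquad b^{*} = \bar{y} - w^{*}\,\bar{x}
\end{aligned}
$$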

 

Host environment: MacBook Pro, TensorFlow 1.4

Output of the author's run:

ljjdeMBP:linear_regression lingjiajun$ ./linear_regression.py
/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:205: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
  return f(*args, **kwds)
2018-04-27 22:20:05.963003: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[ 4.78998518] [ 5.67698431]

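The two values printed above are the fitted weight (w) and bias (b). They are far from the ordinary least-squares fit for this data (w ≈ 2.64, b ≈ 92.8), so the model had not converged in that run.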
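A quick way to judge whether a gradient-based fit has converged is to compare it with the exact least-squares solution, which has a closed form when there is a single input. The sketch below computes it in plain Python (no TensorFlow), reusing the data lists from the script; the helper name `least_squares` is hypothetical and not part of the original post.

```python
# Closed-form ordinary least-squares fit of y = w*x + b (pure Python, no TensorFlow).
x_ = [11, 14, 22, 29, 32, 40, 44, 55, 59, 60, 69, 77]
y_res = [123, 135, 155, 167, 177, 189, 200, 240, 250, 255, 277, 298]


def least_squares(xs, ys):
    """Return (w, b) minimizing the mean squared error of y = w*x + b."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Centered sums: s_xy = sum((x - x_mean) * (y - y_mean)), s_xx = sum((x - x_mean)**2)
    s_xy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    s_xx = sum((x - x_mean) ** 2 for x in xs)
    w = s_xy / s_xx
    b = y_mean - w * x_mean
    return w, b


w, b = least_squares(x_, y_res)
print(round(w, 4), round(b, 4))  # → 2.6419 92.7775
```

Comparing the trained `w_`, `b_` against these values is a cheap convergence check; if they differ by more than a small tolerance, the optimizer needs more steps or a different learning rate.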


Original post: https://www.cnblogs.com/lingjiajun/p/8964889.html
