码迷,mamicode.com
首页 > 其他好文 > 详细

线性回归

时间:2019-04-29 23:49:25      阅读:229      评论:0      收藏:0      [点我收藏+]

标签:res   plt   shel   version   .sh   显示   dirname   main   gre   

  1   
  2 # -*- coding: UTF-8 -*-
  3 """
  4 此脚本用于展示使用sklearn搭建线性回归模型
  5 """
  6 
  7 
  8 import os
  9 import sys
 10 
 11 import numpy as np
 12 import matplotlib.pyplot as plt
 13 import pandas as pd
 14 from sklearn import linear_model
 15 
 16 
 17 def evaluateModel(model, testData, features, labels):
 18     """
 19     计算线性模型的均方差和决定系数
 20     参数
 21     ----
 22     model : LinearRegression, 训练完成的线性模型
 23     testData : DataFrame,测试数据
 24     features : list[str],特征名列表
 25     labels : list[str],标签名列表
 26     返回
 27     ----
 28     error : np.float64,均方差
 29     score : np.float64,决定系数
 30     """
 31     # 均方差(The mean squared error),均方差越小越好
 32     error = np.mean(
 33         (model.predict(testData[features]) - testData[labels]) ** 2)
 34     # 决定系数(Coefficient of determination),决定系数越接近1越好
 35     score = model.score(testData[features], testData[labels])
 36     return error, score
 37 
 38 
 39 def visualizeModel(model, data, features, labels, error, score):
 40     """
 41     模型可视化
 42     """
 43     # 为在Matplotlib中显示中文,设置特殊字体
 44     plt.rcParams[font.sans-serif]=[SimHei]
 45     # 创建一个图形框
 46     fig = plt.figure(figsize=(6, 6), dpi=80)
 47     # 在图形框里只画一幅图
 48     ax = fig.add_subplot(111)
 49     # 在Matplotlib中显示中文,需要使用unicode
 50     # 在Python3中,str不需要decode
 51     if sys.version_info[0] == 3:
 52         ax.set_title(u%s % "线性回归示例")
 53     else:
 54         ax.set_title(u%s % "线性回归示例".decode("utf-8"))
 55     ax.set_xlabel($x$)
 56     ax.set_ylabel($y$)
 57     # 画点图,用蓝色圆点表示原始数据
 58     # 在Python3中,str不需要decode
 59     if sys.version_info[0] == 3:
 60         ax.scatter(data[features], data[labels], color=b,
 61             label=u%s: $y = x + \epsilon$ % "真实值")
 62     else:
 63         ax.scatter(data[features], data[labels], color=b,
 64             label=u%s: $y = x + \epsilon$ % "真实值".decode("utf-8"))
 65     # 根据截距的正负,打印不同的标签
 66     if model.intercept_ > 0:
 67         # 画线图,用红色线条表示模型结果
 68         # 在Python3中,str不需要decode
 69         if sys.version_info[0] == 3:
 70             ax.plot(data[features], model.predict(data[features]), color=r,
 71                 label=u%s: $y = %.3fx$ + %.3f 72                 % ("预测值", model.coef_, model.intercept_))
 73         else:
 74             ax.plot(data[features], model.predict(data[features]), color=r,
 75                 label=u%s: $y = %.3fx$ + %.3f 76                 % ("预测值".decode("utf-8"), model.coef_, model.intercept_))
 77     else:
 78         # 在Python3中,str不需要decode
 79         if sys.version_info[0] == 3:
 80             ax.plot(data[features], model.predict(data[features]), color=r,
 81                 label=u%s: $y = %.3fx$ - %.3f 82                 % ("预测值", model.coef_, abs(model.intercept_)))
 83         else:
 84             ax.plot(data[features], model.predict(data[features]), color=r,
 85                 label=u%s: $y = %.3fx$ - %.3f 86                 % ("预测值".decode("utf-8"), model.coef_, abs(model.intercept_)))
 87     legend = plt.legend(shadow=True)
 88     legend.get_frame().set_facecolor(#6F93AE)
 89     # 显示均方差和决定系数
 90     # 在Python3中,str不需要decode
 91     if sys.version_info[0] == 3:
 92         ax.text(0.99, 0.01, 
 93             u%s%.3f\n%s%.3f 94             % ("均方差:", error, "决定系数:", score),
 95             style=italic, verticalalignment=bottom, horizontalalignment=right,
 96             transform=ax.transAxes, color=m, fontsize=13)
 97     else:
 98          ax.text(0.99, 0.01, 
 99             u%s%.3f\n%s%.3f100             % ("均方差:".decode("utf-8"), error, "决定系数:".decode("utf-8"), score),
101             style=italic, verticalalignment=bottom, horizontalalignment=right,
102             transform=ax.transAxes, color=m, fontsize=13)
103     # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
104     # 在Python shell里面,可以设置参数"block=False",使阻断失效。
105     plt.show()
106 
107 
108 def trainModel(trainData, features, labels):
109     """
110     利用训练数据,估计模型参数
111     参数
112     ----
113     trainData : DataFrame,训练数据集,包含特征和标签
114     features : 特征名列表
115     labels : 标签名列表
116     返回
117     ----
118     model : LinearRegression, 训练好的线性模型
119     """
120     # 创建一个线性回归模型
121     model = linear_model.LinearRegression()
122     # 训练模型,估计模型参数
123     model.fit(trainData[features], trainData[labels])
124     return model
125 
126 
127 def linearModel(data):
128     """
129     线性回归模型建模步骤展示
130     参数
131     ----
132     data : DataFrame,建模数据
133     """
134     features = ["x"]
135     labels = ["y"]
136     # 划分训练集和测试集
137     trainData = data[:15]
138     testData = data[15:]
139     # 产生并训练模型
140     model = trainModel(trainData, features, labels)
141     # 评价模型效果
142     error, score = evaluateModel(model, testData, features, labels)
143     # 图形化模型结果
144     visualizeModel(model, data, features, labels, error, score)
145 
146 
147 def readData(path):
148     """
149     使用pandas读取数据
150     """
151     data = pd.read_csv(path)
152     return data
153 
154 
155 if __name__ == "__main__":
156     homePath = os.path.dirname(os.path.abspath(__file__))
157     # Windows下的存储路径与Linux并不相同
158     if os.name == "nt":
159         dataPath = "%s\\data\\simple_example.csv" % homePath
160     else:
161         dataPath = "%s/data/simple_example.csv" % homePath
162     data = readData(dataPath)
163     linearModel(data)
164 © 2019 GitHub, Inc.
165 Terms
166 Privacy
167 Security
168 Status
169 Help
170 Contact GitHub
171 Pricing
172 API
173 Training
174 Blog
175 About

 

线性回归

标签:res   plt   shel   version   .sh   显示   dirname   main   gre   

原文地址:https://www.cnblogs.com/bbgoal/p/10793527.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!