码迷,mamicode.com
首页 > 其他好文 > 详细

线性回归的pytorch代码

时间:2020-07-04 22:22:52      阅读:74      评论:0      收藏:0      [点我收藏+]

标签:isp   推荐   epo   nta   str   结束   off   base64   c4c   

使用pytorch实现的线性回归, 闲言少叙,直接上代码,客官请看:

 1 import torch
 2 import torch.nn as  nn
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 
 6 #设置相关参数
 7 input_size=1
 8 output_size=1
 9 num_epochs=60
10 learning_rate=0.001
11 
12 #导入训练数据集
13 x_train=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],
14                   [9.779],[6.182],[7.59],[2.167],[7.042],
15                   [10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
16 y_train=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],
17                   [3.366],[2.596],[2.53],[1.221],[2.827],
18                   [3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
19 
20 #设置模型、损失函数、优化函数等
21 model=nn.Linear(input_size,output_size)
22 criterion=nn.MSELoss()
23 optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
24 
25 #开始迭代训练
26 for epoch in  range(num_epochs):
27     inputs=torch.from_numpy(x_train)
28     targets=torch.from_numpy(y_train)
29 
30     outputs=model(inputs)
31     loss=criterion(outputs,targets)
32 
33     optimizer.zero_grad()
34     loss.backward()
35     optimizer.step()
36 
37     if (epoch+1) %5==0:
38         print("Epoch [{}/{}], LossL:{:.4f}".format(epoch+1,num_epochs,loss.item()))
39 #计算出训练之后的期望值/预测值,并与实际值进行画图比较
40 predicted=model(torch.from_numpy(x_train)).detach().numpy()
41 plt.plot(x_train,y_train,ro,label=Original Data)
42 plt.plot(x_train,predicted,label=Fitted Line)
43 plt.legend()
44 plt.show()
45 #保存模型相关数据
46 torch.save(model.state_dict(),model.ckpt

 

-------------------- 正文到此结束------------------------

推荐一个公众号:健哥聊量化,会持续推出股票相关基础知识,以及python实现的一些基本的分析代码。欢迎大家关注,二维码如下:

技术图片技术图片?

相关文章列表如下:

?

线性回归的pytorch代码

标签:isp   推荐   epo   nta   str   结束   off   base64   c4c   

原文地址:https://www.cnblogs.com/dataat/p/13236836.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!