
Python 3: Gradient Descent

Posted: 2018-12-12 17:38:25


Gradient descent is an iterative method that can be used to solve least-squares problems, both linear and nonlinear. When estimating the parameters of a machine learning model, i.e. solving an unconstrained optimization problem, gradient descent is one of the most commonly used approaches; the other common approach is the method of least squares. To find the minimum of a loss function, gradient descent iterates step by step toward it, yielding the minimized loss and the corresponding parameter values. Conversely, if we need the maximum of a function, we iterate with gradient ascent instead. Building on the basic method, two variants are widely used in machine learning: stochastic gradient descent and batch gradient descent.
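As a minimal illustration of the update rule (my example, not from the original post), plain gradient descent on the one-dimensional function f(x) = (x - 3)^2 repeatedly steps opposite the gradient until it reaches the minimum at x = 3:

```python
def gradient_descent(x0, alpha=0.1, steps=100):
    """Minimize f(x) = (x - 3)**2 starting from x0."""
    x = x0
    for _ in range(steps):
        grad = 2 * (x - 3)   # analytic gradient f'(x)
        x -= alpha * grad    # update rule: x := x - alpha * f'(x)
    return x

x_min = gradient_descent(0.0)  # converges toward 3.0
```

Each step shrinks the distance to the minimum by a constant factor (here 1 - 2*alpha = 0.8), which is the same contraction behaviour the multi-parameter code below exhibits.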

[Figure: gradient descent, the most commonly used method for solving unconstrained optimization problems.]

# Training samples: each tuple holds the features (x0, x1, x2)
x = [(1, 3., 5.), (10, 12., 8.), (4., 8., 2.)]
y = [20, 62, 36]     # target output for each sample

epsilon = 0.001      # convergence threshold on the change in loss
alpha = 0.01         # learning rate
max_itor = 1000      # iteration cap
error1 = 0           # loss of the current pass
error0 = 0           # loss of the previous pass
cnt = 0              # iteration counter

# Initialise the parameters of the model y = theta0*x0 + theta1*x1 + theta2*x2
theta0 = 0
theta1 = 0
theta2 = 0

while True:
    cnt += 1
    # Stochastic updates: adjust the parameters after every sample
    for i, x_sample in enumerate(x):
        # Residual of the current sample
        diff = (theta0*x_sample[0] + theta1*x_sample[1] + theta2*x_sample[2]) - y[i]
        # Parameter update: theta -= learning rate * residual * feature
        theta0 -= alpha * diff * x_sample[0]
        theta1 -= alpha * diff * x_sample[1]
        theta2 -= alpha * diff * x_sample[2]

    # Loss: sum of absolute residuals over all samples, recomputed each pass
    error1 = 0
    for i, x_sample in enumerate(x):
        error1 += abs(y[i] - (theta0*x_sample[0] + theta1*x_sample[1] + theta2*x_sample[2]))

    print("Iteration: {}".format(cnt))
    if abs(error1 - error0) < epsilon or cnt > max_itor:
        break
    error0 = error1
    print("theta0:{},theta1:{},theta2:{},error1:{}".format(theta0, theta1, theta2, error1))

print("Final parameters: theta0:{},theta1:{},theta2:{},error1:{}".format(theta0, theta1, theta2, error1))
print("Iterations: {}, {}".format(cnt, abs(error1 - error0)))
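The loop above updates the parameters after each individual sample, which is the stochastic variant mentioned in the introduction. The batch variant instead accumulates the gradient over all samples and performs one update per pass. A sketch on the same data and model (my code, not the original author's):

```python
# Batch gradient descent on the same 3-sample linear model.
x = [(1, 3., 5.), (10, 12., 8.), (4., 8., 2.)]
y = [20, 62, 36]
alpha = 0.01
theta = [0.0, 0.0, 0.0]

for _ in range(20000):
    grad = [0.0, 0.0, 0.0]
    for xs, target in zip(x, y):
        # Residual of this sample under the current parameters
        diff = sum(t * xi for t, xi in zip(theta, xs)) - target
        for j in range(3):
            grad[j] += diff * xs[j]            # accumulate over the whole batch
    for j in range(3):
        theta[j] -= alpha * grad[j] / len(x)   # average gradient, one update per pass
```

With three samples and three unknowns the linear system has an exact solution, theta = (1/13, 53/13, 20/13) ≈ (0.0769, 4.0769, 1.5385), which both the stochastic and the batch variant approach; averaging the gradient over the batch keeps the effective step size small enough for the iteration to stay stable.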

The results of the originally posted run are as follows (iterations 6 through 85 omitted for brevity):

Iteration: 1
theta0:3.09216,theta1:2.8003199999999993,theta2:3.7900799999999997,error1:49.64095999999997
Iteration: 2
theta0:1.9569175142400004,theta1:2.33457205248,theta2:2.0365607731200006,error1:57.79668039679997
Iteration: 3
theta0:2.066932417249936,theta1:2.7722210571318064,theta2:2.051718445195592,error1:68.23965761462367
Iteration: 4
theta0:1.822654631584629,theta1:2.901039039580537,theta2:1.6667578137327976,error1:75.92072762586017
Iteration: 5
theta0:1.693555373168776,theta1:3.08255331426641,theta2:1.4929122977800056,error1:82.96397010375048
...
Iteration: 86
theta0:0.07706518458378205,theta1:4.076892068320665,theta2:1.5382806600512935,error1:157.2497306283089
Iteration: 87
theta0:0.07704916226627687,theta1:4.076895577816928,theta2:1.5383010124654535,error1:157.2508797175436
Iteration: 88
theta0:0.07703494564181095,theta1:4.076898689569265,theta2:1.5383190780730835,error1:157.2518995017271
Iteration: 89
Final parameters: theta0:0.07702233134268946,theta1:4.076901448746194,theta2:1.5383351132779755,error1:157.25280451125468
Iterations: 89, 0.0009050095275711101
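As a sanity check (mine, not part of the original post), plugging the final parameters back into the model reproduces the three sample outputs almost exactly:

```python
# Evaluate the fitted linear model at each training sample.
x = [(1, 3., 5.), (10, 12., 8.), (4., 8., 2.)]
y = [20, 62, 36]
# Final parameters taken from the run above
theta = (0.07702233134268946, 4.076901448746194, 1.5383351132779755)

preds = [sum(t * xi for t, xi in zip(theta, xs)) for xs in x]
residuals = [p - t for p, t in zip(preds, y)]  # all close to zero
```

The predictions come out near (20.0, 62.0, 36.0), confirming the parameters have essentially converged to the exact solution of this 3x3 system.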


Original article: http://blog.51cto.com/13959448/2329456
