码迷,mamicode.com
首页 > 其他好文 > 详细

78、tensorflow滑动平均模型,用来更新迭代的衰减系数

时间:2017-04-21 09:40:28      阅读:231      评论:0      收藏:0      [点我收藏+]

标签:count   code   ble   port   dict   动态控制   add   sam   app   

‘‘‘
Created on 2017年4月21日

@author: weizhen
‘‘‘
#4、滑动平均模型
import tensorflow as tf
#定义一个变量用于计算滑动平均,这个变量的初始值为0.
#类型为tf.float32,因为所有需要计算滑动平均的变量必须是实数型
v1=tf.Variable(0,dtype=tf.float32)
#这里step变量模拟神经网络中迭代的轮数,可以用于动态控制衰减率
step=tf.Variable(0,trainable=False)
#定义一个滑动平均的类(class)。初始化时给定了衰减率(0.99)和控制衰减率的变量step
ema=tf.train.ExponentialMovingAverage(0.99,step)
#定义一个更新变量滑动平均的操作。这里需要给定一个列表,每次执行这个操作时
#这个列表中的变量都会被更新
maintain_averages_op=ema.apply([v1])
with tf.Session() as sess:
    #初始化所有变量
    init_op=tf.global_variables_initializer()
    sess.run(init_op)
    
    #通过ema.average(v1)获取滑动平均之后变量的取值。在初始化之后变量v1的值和v1的滑动平均都为0
    print(sess.run([v1,ema.average(v1)]))   #输出[0.0,0.0]
    
    #更新变量v1的值到5
    sess.run(tf.assign(v1,5))
    #更新v1的滑动平均值。衰减率为min{0.99,(1+step)/(10+step)=0.1}=0.1
    #所以v1的滑动平均会被更新为0.1*0+0.9*5=4.5
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)]))
    
    #更新step的值为10000
    sess.run(tf.assign(step,10000))
    #更新v1的值为10
    sess.run(tf.assign(v1,10))
    #更新v1的滑动平均值。衰减率为min{0.99,(1+step)/(10+step)=0.999}}=0.99
    #所以v1的滑动平均会被更新为0.99*4.5+0.01*10=4.555
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)]))
    #输出[10.0,4.5549998]
    #再次更新滑动平均值,得到的新滑动平均值为0.99*4.555+0.01*10=4.60945
    sess.run(maintain_averages_op)
    print(sess.run([v1,ema.average(v1)]))
    #输出[10.0,4.6094499]

输出的结果如下所示

E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "BestSplits" device_type: "CPU") for unknown op: BestSplits
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "CountExtremelyRandomStats" device_type: "CPU") for unknown op: CountExtremelyRandomStats
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "FinishedNodes" device_type: "CPU") for unknown op: FinishedNodes
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "GrowTree" device_type: "CPU") for unknown op: GrowTree
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "ReinterpretStringToFloat" device_type: "CPU") for unknown op: ReinterpretStringToFloat
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "SampleInputs" device_type: "CPU") for unknown op: SampleInputs
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "ScatterAddNdim" device_type: "CPU") for unknown op: ScatterAddNdim
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "TopNInsert" device_type: "CPU") for unknown op: TopNInsert
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "TopNRemove" device_type: "CPU") for unknown op: TopNRemove
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "TreePredictions" device_type: "CPU") for unknown op: TreePredictions
E c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:943] OpKernel (op: "UpdateFertileSlots" device_type: "CPU") for unknown op: UpdateFertileSlots
[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.5549998]
[10.0, 4.6094499]

 

78、tensorflow滑动平均模型,用来更新迭代的衰减系数

标签:count   code   ble   port   dict   动态控制   add   sam   app   

原文地址:http://www.cnblogs.com/weizhen/p/6741705.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!