码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow-tf.while_loop

时间:2018-12-23 12:01:35      阅读:482      评论:0      收藏:0      [点我收藏+]

标签:嵌套   env   dsl   除了   python2   lse   特定   用户控制   pytho   

tf.while_loop(
? ? cond,
? ? body,
? ? loop_vars,
? ? shape_invariants=None,
? ? parallel_iterations=10,
? ? back_prop=True,
? ? swap_memory=False,
? ? name=None,
? ? maximum_iterations=None,
? ? return_same_structure=False
)

当条件谓词cond为True,重复body

cond是一个调用,返回一个boolean的标量tensor,body是一个调用,返回一个(可能嵌套的)元组。tensors的命名元组或列表与loop_vars有相同数量(长度和结构)和类型,loop_vars是一个(可能嵌套的)元祖。tensors的命名元组或列表传送到cond和body中。cond和body都采用和loop_vars一样多的参数。

除了常规张量或索引片断外,body还可以接受和返回TensorArray对象。TensorArray的流动将在循环期间和梯度计算期间适当地传送。

while_loop调用cond和body刚好一次(在?while_loop内部,而不是 Session.run()间。while_loop将cond和body调用期间创建的计算图片段与一些附加的图节点缝合在一起,重复body直到cond返回false。

为了正确起见,tf.while_loop()?严格执行循环变量的形状不变量。形状不变量是一个(可能是部分的)形状,在循环的迭代过程中是不变的。如果迭代后的循环变量的形状被确定为比其形状不变量更通用或不兼容,则会产生错误。例如,[11, None]的形状比[11, 17 ]的形状更一般,并且[11, 21 ]与[11, 17 ]不兼容。默认情况下(如果没有指定参数shape_invariants),则假定loop_vars中的每个张量的初始形状在每次迭代中都是相同的。?shape_invariants参数允许调用者为每个循环变量指定一个不太特定的形状不变量,如果形状在迭代之间发生变化,则需要使用该形状不变量。还可以在body函数中使用tf.Tensor.set_shape函数来指示输出循环变量具有特定的形状。SparseTensor和索引片断的形状不变量被特别地处理如下:

a)如果循环变量是稀疏张量(SparseTensor),则形状不变量必须是张量形状(TensorShape[r]),其中r是由稀疏张量表示的稠密张量的秩。这意味着闪光灯的三个张量的形状是([None], [None, r], [r])。注意:这里的形状不变量是SparseTensor.dense_shape形状属性的形状。这必须是一个vector的形状。

b)如果循环变量是IndexedSlices,则形状不变量必须是IndexedSlices的值张量的形状不变量。这意味着索引数组(IndexedSlice)的三个张量的形状是(shape, [shape[0]], [shape.ndims])

while_loop实现非严格语义,使多个迭代并行运行。并行迭代的最大次数可以由parallel_iterations控制,这允许用户控制内存消耗和执行顺序。对于正确的程序,HyyLoad应该为任何parallel_iterations>0返回相同的结果。

对于训练,TensorFlow存储张量(tensors),这些张量在前向传递和反射传播中产生,这些张量是内存消耗的主要来源,并且经常在GPU上训练时导致OOM错误。当标志swap_memory为真时,我们将这些张量从GPU交换到CPU。例如,这允许我们训练具有很长序列和大批量的RNN模型。

参数:

cond:可调用的表示循环终止条件。?
body: 表示循环体的可调用体。
loop_vars: 一个(可能嵌套)元组、namedtuple或numpy数组列表、张量、和TensorArray?对象
shape_invariants: 循环变量的形状不变量。
parallel_iterations: 允许并行运行的迭代次数。它必须是正整数。
back_prop: 是否在这个while循环中启用了backprop。
swap_memory: 是否为这个循环启用了GPU-CPU内存交换。
name: 返回的张量的可选名称前缀。
maximum_iterations: 可选while循环运行的最大迭代次数.如果提供,则cond输出与确保执行的迭代数量不大于maximum_iterations?
return_same_structure: 如果True, 输出具有与loop_vars相同的结构。如果启用了紧急执行,则忽略此操作(并且始终被视为true)。
返回:

循环后的循环变量的输出张量。如果return_same_structure是True,返回值的结构与loop_vars相同。如果return_same_structure是False,则返回的是一个Tensor、TensorArray或IndexedSlice,否则为列表。

Raises:

TypeError:如果cond或body不能调用
ValueError: 如果loop_vars为空

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 27 11:16:32 2018
@author: myhaspl
"""

import tensorflow as tf
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
sess=tf.Session()
with sess:
    print sess.run(r)

运行结果为10,函数b的最后运行结果。

循环将i每次增加1,直到10

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 27 11:16:32 2018
@author: myhaspl
"""

import tensorflow as tf
i = tf.constant(100)
def b(i):
    res=tf.subtract(i, 2)
    return res

c = lambda i: tf.greater(i, 0)

r = tf.while_loop(c, b, [i])

sess=tf.Session()
with sess:
    print sess.run(r)

循环每次将i减1,直到0

运行结果为0,函数b的最后运行结果。

tensorflow-tf.while_loop

标签:嵌套   env   dsl   除了   python2   lse   特定   用户控制   pytho   

原文地址:http://blog.51cto.com/13959448/2334033

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!