码迷,mamicode.com
首页 > 其他好文 > 详细

tf.nn.rnn_cell.MultiRNNCell

时间:2019-03-13 12:39:36      阅读:243      评论:0      收藏:0      [点我收藏+]

标签:unit   pre   __init__   图片   type   var   dom   output   strong   

  • Class tf.contrib.rnn.MultiRNNCell
  • Class tf.nn.rnn_cell.MultiRNNCell

构建多隐层神经网络

__init__(cells, state_is_tuple=True)

cells:rnn cell 的list

state_is_tuple:true,状态Ct和ht就是分开记录,放在一个tuple中,接受和返回的states是n-tuples,其中n=len(cells),False,states是concatenated沿着列轴.后者即将弃用。

 

BasicLSTMCell 单隐层

技术图片

 

BasicLSTMCell 多隐层

技术图片

 

代码示例

# encoding:utf-8
import tensorflow as tf

batch_size=10
depth=128

inputs=tf.Variable(tf.random_normal([batch_size,depth]))

previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100]))
previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200]))
previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300]))

num_units=[100,200,300]
print(inputs)

cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units]
mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells)

outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2))

print(outputs.shape) #(10, 300)
print(states[0]) #第一层LSTM
print(states[1]) #第二层LSTM
print(states[2]) ##第三层LSTM
print(states[0].h.shape) #第一层LSTM的h状态,(10, 100)
print(states[0].c.shape) #第一层LSTM的c状态,(10, 100)
print(states[1].h.shape) #第二层LSTM的h状态,(10, 200)

输出

(10, 300)
LSTMStateTuple(c=<tf.Tensor multi_rnn_cell/cell_0/basic_lstm_cell/Add_1:0 shape=(10, 100) dtype=float32>, h=<tf.Tensor multi_rnn_cell/cell_0/basic_lstm_cell/Mul_2:0 shape=(10, 100) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor multi_rnn_cell/cell_1/basic_lstm_cell/Add_1:0 shape=(10, 200) dtype=float32>, h=<tf.Tensor multi_rnn_cell/cell_1/basic_lstm_cell/Mul_2:0 shape=(10, 200) dtype=float32>)
LSTMStateTuple(c=<tf.Tensor multi_rnn_cell/cell_2/basic_lstm_cell/Add_1:0 shape=(10, 300) dtype=float32>, h=<tf.Tensor multi_rnn_cell/cell_2/basic_lstm_cell/Mul_2:0 shape=(10, 300) dtype=float32>)
(10, 100)
(10, 100)
(10, 200)

 

tf.nn.rnn_cell.MultiRNNCell

标签:unit   pre   __init__   图片   type   var   dom   output   strong   

原文地址:https://www.cnblogs.com/yanshw/p/10515436.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!