码迷,mamicode.com
首页 > 其他好文 > 详细

【NLP】RNN、LSTM、GRU

时间:2020-07-17 19:40:13      阅读:72      评论:0      收藏:0      [点我收藏+]

标签:解码   mda   ash   because   dash   状态   block   log   detail   

RNN

循环神经网络。像之前的CNN只能处理单独的输入,前一个输入与后一个输入没有关系。但例如NLP中,我们需要前后文的信息。所以RNN应运而生。

标准的RNN中,1)N input -- N output  2)权值共享,W、U、V每个都是一样的。

实际中,这一种结构无法解决所有问题。所以也有了以下变形:

1)输入序列 N,输出一个。例如文本情感分类

技术图片

2)输入 序列M,输出序列N,不等长。

这种结构又叫做Encoder-Decoder,也可称为Seq2Seq模型。其结构原理是先编码后解码。左侧的RNN用来编码得到 C,再由右侧的RNN对 C 进行解码。

技术图片

其中得到 C 的方法有很多种:

技术图片

RNN的训练方法——BPTT

BPTT 即 Back-propagation through time 基于时间反向传播,本质还是BP算法,只不过要基于时间反向传播。

在最朴素的RNN,也即 N - N 的场景下:

技术图片

技术图片

其中 X 是输入序列,S 是状态序列(记忆),X 是输入序列, O 是输出序列。下文中 O 以 E 代替,因为参考的博客中写的是E,方便截图公式 :P

技术图片

RNN的公式如下:

技术图片

 Z   =   V * St

 技术图片

因为RNN中损失是累加的,所以总损失需要求和。

 Just like we sum up the errors, we also sum up the gradients at each time step for one training example

例如:

技术图片

 以第 3 时刻为例,对 V 求导(最简单,因为只依赖当前时刻,求导到 Z 就结束了,没有到 S)。其中 Z =   V * S3

技术图片

 对 W 求导(与对 U 求导类似),他们都依赖于前面的时刻,因为需要经过 S:

技术图片

 然而 S3 又依赖于前一时刻的 S2 和 W,不能把 S3 当成简单的常数看待,需要继续打开链式法则:

技术图片

于是展开求和得:

技术图片

We sum up the contributions of each time step to the gradient. In other words, because 技术图片 is used in every step up to the output we care about, we need to backpropagate gradients from 技术图片 through the network all the way to 技术图片:

每个时间步长,W 都有贡献,所以 加和起来

技术图片

和传统神经网络的主要区别就是,在RNN中,需要把每个时间步长的 W 加和起来。

在传统神经网络中,我们不会跨 layer 共享参数,所以求导时不需要做加和操作。

通俗来讲,BPTT其实就是在展开的RNN上进行传统的反向传播。

技术图片是对 S 自己的连式法则,例如:技术图片

 

所以对 W 求导的公式可重写成:

技术图片

 

对 V 求导同理。

 

RNN梯度消失的问题:

http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

 

 

 

 

RNN,LSTM, GRU(需要参数量少), BiLSTM区别

LSTM为什么可以解决RNN中梯度消失的问题

因为把RNN中连乘,由forget gate转换变成了引入加号。乘法变成了加法,RNN中memory里的值总是被覆盖,而LSTM是memory和input相加,乘以一个数相加。除非forget gate被打开。

 

 

参考摘录:

【1】https://zhuanlan.zhihu.com/p/30844905

【2】https://blog.csdn.net/zhaojc1995/article/details/80572098

【3】 http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

 

【NLP】RNN、LSTM、GRU

标签:解码   mda   ash   because   dash   状态   block   log   detail   

原文地址:https://www.cnblogs.com/YeZzz/p/13172417.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!