RNN训练算法BPTT介绍

RNN训练算法BPTT介绍 本篇文章第一部分翻译自:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/,英文好的朋友可以直接看原文。最近看到RNN,先是困惑于怎样实现隐藏层的互联,搞明白之后又不太明白如何使用BPTT进…

大家好,又见面了,我是你们的朋友全栈君。

 

本篇文章第一部分翻译自:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/,英文好的朋友可以直接看原文。

最近看到RNN,先是困惑于怎样实现隐藏层的互联,搞明白之后又不太明白如何使用BPTT进行训练,在网上找资源发现本篇博客介绍较为详细易懂,自己翻译了一遍,以下:

RNN教程,第3部分,通过时间反向传播(BPTT)和梯度消失

这是RNN教程的第三部分。

在本教程的前面部分我们从头开始实现了一个RNN网络,但是没有探究实现BPTT计算梯度的细节。在这部分我们将给出BPTT的简要概述并且解释它和传统反向传播算法的区别。随后我们将致力于理解梯度消失问题(vanishing gradient problem),这个问题促成了LSTMs和GRUs的发展,在NLP(和其他领域),它们是当前最受欢迎和最为强大的模型中的两种。梯度消失问题最早于1991年有Sepp Hochreiter发现,最近由于深度结构的应用增多而重新受到关注。

如果想完全理解这部分内容,我建议你对偏导数和基本的反向传播工作很熟悉。如果你还不熟悉,你可以从【文中提供了三个地址】找到好的教程,它们随着难度的上升而排序。

Backpropagation Through TIme(BPTT)

我们先快速回顾一下RNN的等式。注意到这里有一个小变化,符号o变成了\widehat{y}。这是为了和我参考的一些文献保持一致。

RNN训练算法BPTT介绍

我们同时定义我们的损失函数(或者称为误差)为交叉熵损失,由以下公式给出:

RNN训练算法BPTT介绍

这里{y}_t是t时刻的正确单词,\widehat{y}_t是网络的预测。典型的,我们将完整的序列(句子)是为一个训练实例,所以总的误差是各个时间点(单词)误差的和。

RNN训练算法BPTT介绍

我们的目的是计算误差关于参数U,V和W的梯度并通过随机梯度下降(SGD)来学习好的参数。正如我们计算了误差的和,我们也将一个训练实例各个时间地啊你的梯度做一个求和:\frac{\partial E}{\partial W} = \sum\limits_{t} \frac{\partial E_t}{\partial W} 。

我们使用链式求导来计算这些导数。这是从误差开始后应用反向传播算法。在这篇文章的剩余部分我们将使用 E_3作为例子,这只是为了用一个实际的数来做推导。

\begin{aligned}  \frac{\partial E_3}{\partial V} &=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial V}\\  &=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial z_3}\frac{\partial z_3}{\partial V}\\  &=(\hat{y}_3 - y_3) \otimes s_3 \\  \end{aligned}

在上面的式子中,z_3 =Vs_3,同时\otimes表示两个向量的外积运算。如果上面讲的你跟不上也不用担心,我跳过了一些步骤,你可以自己尝试计算这些导数(这是一个很好的锻炼!)。我想从上面式子中得到的是\frac{\partial E_3}{\partial V}的计算仅仅依赖于当前时间点的数值\hat{y}_3, y_3, s_3。如果你掌握着这些,计算误差关于V的导数就仅仅是一个简单的矩阵乘法。

但是对于\frac{\partial E_3}{\partial W}(和U)的情况却是不同的。我们列出链式法则来一探究竟,与上面类似:

\begin{aligned}  \frac{\partial E_3}{\partial W} &= \frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial s_3}\frac{\partial s_3}{\partial W}\\  \end{aligned}

现在应该注意到的是s_3 = \tanh(Ux_t + Ws_2)依赖于s_2,而s_2又依赖于W和s_1,以此类推。如果我们计算关于W的导数我们不能简单地将s_2视为常量!我们需要再次使用链式法则,我们最终获得的表达式为:

\begin{aligned}  \frac{\partial E_3}{\partial W} &= \sum\limits_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial W}\\  \end{aligned}

我们将每个时间点对梯度的贡献求和。话句话说,由于在到达我们所关心的输出的过程中的每一步计算中都用了W,我们需要从t=3开始在网络中的每一个路径反向传播梯度直到t=0。

RNN训练算法BPTT介绍

注意到这和我们在深度前向神经网络中使用的标准反向传播算法是一样的。最主要的区别在于我们计算了关于W每个时间点上的梯度并将它们求和。传统神经网络中我们不会在层间分享参数,所以也不用做任何求和。但是在我看来BPTT不过是标准反向传播在没展开的RNN的一个有趣的名字。类似反向传播你可以定义一个向后传播的δ矢量,例如:\delta_2^{(3)} = \frac{\partial E_3}{\partial z_2} =\frac{\partial E_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial z_2},这里z_2 = Ux_2+ Ws_1。然后应用相同的方程式。

一个简单的BPTT实现类似于下面的代码:

RNN训练算法BPTT介绍

翻译结束,原文后续部分探讨梯度消失。

推到一下上面的公式:

RNN训练算法BPTT介绍

部分参考;

http://blog.sina.cn/dpool/blog/s/blog_6e32babb0102y3u7.html

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152328.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • dev-c++ 总是坏

    dev-c++ 总是坏

  • word去掉万恶的域代码

    word去掉万恶的域代码背景:写论文使用mathtype插入公式后,有时候会显示域代码更新软件环境:macos10.14.6,word16.45解决过程:刚开始在域代码那里接受修订,然后重新打开word还是现实域代码,去网上找了教程(mac下)先选中,然后command+shift+F9解决了问题。如果是windows应该是ctrl+shift+F9(我没试过哈,猜的)…

  • 慧荣SM2246XT主控的固态硬盘修复开卡不识别怎么短接方法

    慧荣SM2246XT主控的固态硬盘修复开卡不识别怎么短接方法一块坏了的240G的固态硬盘,电脑完全不认盘了,所以想修复一下,拆开看到主控是慧荣的SM2246XT,幸好此主控是有开卡软件的,下载也比较方便,最新的SM2246XT_MP_EnhancePageMode_MPQ1102A_DBQ0412_FWR1212A.rar修复成功率很高,但跟U盘量产不同,固态硬盘开卡是需要短接的,但很多人不知道SM2246XT的固态硬盘该怎么短接,这里就教大家。如图拆开ssd外壳后可以看到板子上有ROMMODE的字样,那里就是短接的位置了,这里的4个短接点,我们是需要用镊子两两短

  • 使用批处理命令向win server AD域中批量添加用户实现

    使用批处理命令向win server AD域中批量添加用户实现因为要用个批处理命令在WindowsServer里面批量添加域用户,所以需要使用批处理命令。我这篇是纯新手教程,在百度上搜了一些批处理命令感觉属于进阶教程,研究了两天才完成我要完成的目标。下面从头说一下:批处理bat文档建立。直接新建一个TXT文档然后把后缀名改成.bat就可以了,就是一个bat文档,双击可以运行。注意:bat文件在哪,他的运行路径就在哪。添加成功的用户

  • springcloud kafka 分布式配置中心管理

    springcloud kafka 分布式配置中心管理

  • 算法刷题LeetCode中文版_leetcode简单题

    算法刷题LeetCode中文版_leetcode简单题目录二分查找排序的写法BFS的写法DFS的写法回溯法树递归迭代前序遍历中序遍历后序遍历构建完全二叉树并查集前缀树图遍历Dijkstra算法Floyd-Warshall算法Bellman-Ford算法最小生成树Kruskal算法Prim算法拓扑排序查找子字符串,双指针模板动态规划状态搜索贪心本文的目的是收集一些典型的题目,记住其写法,理解其思想,即可做到一通百通。欢迎大家提出宝贵意见!二分查找…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号