详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。上面两幅图片,已经很详细的展示了损失是如何产生的,以及…

大家好,又见面了,我是你们的朋友全栈君。

上一节我们说了详细展示RNN的网络结构以及前向传播,在了解RNN的结构之后,如何训练RNN就是一个重要问题,训练模型就是更新模型的参数,也就是如何进行反向传播,也就意味着如何对参数进行求导。本篇内容就是详细介绍RNN的反向传播算法,即BPTT。


首先让我们来用动图来表示RNN的损失是如何产生的,以及如何进行反向传播,如下图所示。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

上面两幅图片,已经很详细的展示了损失是如何产生的, 以及如何来对参数求导,这是忽略细节的RNN反向传播流程,我相信已经描述的非常清晰了。下图(来自trask)描述RNN详细结构中反向传播的过程。

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

有了清晰的反向传播的过程,我们接下来就需要进行理论的推到,由于符号较多,为了不至于混淆,根据下图,现标记符号如表格所示:

详细阐述基于时间的反向传播算法(Back-Propagation Through Time,BPTT)「建议收藏」

公式符号表
符号 含义

K

输入向量的大小(one-hot长度,也是词典大小)

T

输入的每一个序列的长度

H

隐藏层神经元的个数

X=\left \{ x_{1},x_{2},x_{3}....,x_{T} \right \}

样本集合

x_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻的输入

y_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻经过Softmax层的输出。

\hat{y}_{t}\epsilon \mathbb{R}^{K\times 1}

t时刻输入样本的真实标签

L_{t}

t时刻的损失函数,使用交叉熵函数,

L_t=-\hat{y}_t^Tlog(y_t)

L

序列对应的损失函数:

L=\sum\limits_t^T L_t

RNN的反向传播是每处理完一个样本就需要对参数进行更新,因此当执行完一个序列之后,总的损失函数就是各个时刻所得的损失之和。

s_{t}\epsilon \mathbb{R}^{H\times 1}

t个时刻RNN隐藏层的输入。

h_{t}\epsilon \mathbb{R}^{H\times 1}

第t个时刻RNN隐藏层的输出。

z_{t}\epsilon \mathbb{R}^{H\times 1}

输出层的输入,即Softmax函数的输入

W\epsilon \mathbb{R}^{H\times K}

输入层与隐藏层之间的权重。

U\epsilon \mathbb{R}^{H\times H}

上一个时刻的隐藏层 与 当前时刻隐藏层之间的权值。

V\epsilon \mathbb{R}^{K\times H}

隐藏层与输出层之间的权重。

                                                                     \begin{matrix} \: \: \: \: \: \: \: \: \; \; \; \; \; \; \; \; \; \; \; \; \; s_t=Uh_{t-1}+Wx_t+b\\ \\ h_t=\sigma(s_t)\\ \\ \; \; \; \; z_t=Vh_t+c\\ \\ \; \; \; \; \; \; \; \; \; \; y_t=\mathrm{softmax}(z_t) \end{matrix}

我们对参数V,c求导比较方便,只有每一时刻的输出对应的损失与V,c相关,可以直接进行求导,即:

                                                  \frac{\partial L}{\partial V} =\sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial V} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial V} = \sum\limits_{t=1}^{\tau}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                 \frac{\partial L}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial c} = \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}

要对参数W,U,b进行更新,就不那么容易了,因为参数W,U,b虽是共享的,但是他们不只是对第t刻的输出做出了贡献,同样对t+1时刻隐藏层的输入s_{t+1}做出了贡献,因此在对W,U,b参数求导的时候,需要从后向前一步一步求导。

假设我们在对t时刻的参数W,U,b求导,我们利用链式法则可得出:

                                                                   \frac{\partial L}{\partial W}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial W}

                                                                    \frac{\partial L}{\partial U}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial U}

                                                                    \frac{\partial L}{\partial b}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial b}

我们发现对W,U,b进行求导的时候,都需要先求出\frac{\partial L}{\partial h_{t}},因此我们设:

                                                                         \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}

那么我们现在需要先求出\delta ^{t},则:

                                                                        \frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}=V^{T}(y_{t}-\hat{y}_{t})

                                                                   \begin{matrix} \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}=U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\delta ^{t+1}) \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\sigma ^{'}(z_{t+1}))\delta ^{t+1} \\ \\ =U^{T}diag(1-h_{t+1}^{2})\delta ^{t+1} \end{matrix}

注:在求解激活函数导数时,是将已知的部分求导之后,然后将它和激活函数导数部分进行哈达马乘积。激活函数的导数一般是和前面的进行哈达马乘积,这里的激活函数是双曲正切,用矩阵中对角线元素表示向量中各个值的导数,可以去掉哈达马乘积,转化为矩阵乘法。

则:

                                                                \begin{matrix} \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}\\ \\\: \; \; \; \; \; \; \; \; \; \; \; \; \; \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ \: \: \: \: \: \: \: \: \: \: \: \: \: \: \: \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1} (1-h_{t+1}^{2}) \end{matrix}

我们求得\delta ^{t},之后,便可以回到最初对参数的求导,因此有:

                                                           \frac{\partial L}{\partial W} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial W} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                            \frac{\partial L}{\partial b}= \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial b} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                           \frac{\partial L}{\partial U} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial U} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

有了各个参数导数之后,我们可以进行参数更新:

                                                          W^{'}=W-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                           U^{'}=U-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

                                                          V^{'}=V-\theta \sum\limits_{t=1}^{T}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                           b^{'}=b-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                          c^{'}=c- \theta \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}


参考:

刘建平《循环神经网络(RNN)模型与前向反向传播算法

李弘毅老师《深度学习》

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152338.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号