大家好,又见面了,我是你们的朋友全栈君。
摘要:
在前面的文章里面,RNN训练与BP算法,我们提到了RNN的训练算法。但是回头看的时候在时间的维度上没有做处理,所以整个推导可能存在一点问题。
那么,在这篇文章里面,我们将介绍bptt(Back Propagation Through Time)算法如在训练RNN。
关于bptt
这里首先解释一下所谓的bptt,bptt的思路其实很简单,就是把整个RNN按时间的维度展开成一个“多层的神经网络”。具体来说比如下图:
既然RNN已经按时间的维度展开成一个看起来像多层的神经网络,这个时候用普通的bp算法就可以同样的计算,只不过这里比较复杂的是权重共享。比如上图中每一根线就是一个权重,而我们可以看到在RNN由于权重是共享的,所以三条红线的权重是一样的,这在运用链式法则的时候稍微比较复杂。
正文:
首先,和以往一样,我们先做一些定义。
hti=f(netthi)
netthi=∑m(vimxtm)+∑s(uisht−1s)
nettyk=∑mwkmhtm
最后一层经过softmax的转化
otk=enettyk∑k′enettyk′
在这里我们使用交叉熵作为Loss Function
Et=−∑kztklnotk
我们的任务同样也是求 ∂E∂wkm 、 ∂E∂vim 、 ∂E∂uim 。
注意,这里的 E 没有时间的下标。因为在RNN里,这些梯度分别为各个时刻的梯度之和。
即:
∂E∂wkm=∑stept=0∂Et∂wkm
∂E∂vim=∑stept=0∂Et∂vim
∂E∂uim=∑stept=0∂Et∂uim 。
所以下面我们推导的是 ∂Et∂wkm 、 ∂Et∂vim 、 ∂Et∂uim 。
我们先推导 ∂Et∂wkm 。
∂Et∂wkm=∑k′∂Et∂otk′∂otk′∂nettyk∂nettyk∂wkm=(otk−ztk)∗htm 。(这一部分的推导在前面的文章已经讨论过了)。
在这里,记误差信号:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk) (后面会用到)
对于 ∂Et∂vim 、 ∂Et∂uim 其实是差不多的,所以这里详细介绍其中一个。这两个导数也是RNN里面最复杂的。
推导: ∂Et∂vim
∂Et∂vim=∑tt′=0∂Et∂nett′hi∂nett′hi∂vim
对于这个式子第一次看可能有点懵逼,这里稍微解释一下:
从式: hti=f(∑m(vimxtm)+∑s(uisht−1s)) 中我们可以看到, vim 影响的是所有时刻的 netthi,t=0,1,2,....step 。所以当 Et 对 vim 求偏导的时候,由于链式法则需要考虑到所有时刻的 netthi 。
下面分成两部分来求 ∂Et∂nett′hi , ∂nett′hi∂vim. 。
第一部分: ∂Et∂nett′hi 。
这里我们记 δ(t′,t)i=∂Et∂nett′hi (误差信号,和前面文章一样)。
(由于带着符号去求这两个导数会让人看起来非常懵逼,所以下面指定具体的值,后面抽象给出通式)
假设共3个时刻,即t=0,1,2。
对于 t=2 , t′=2 时:
( E2 表示第2个时刻(也是最后一个时刻)的误差)
( net2hi 表示第2个时刻隐藏层第i个神经元的净输入)
具体来说: ∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi
对于 ∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i
由于 δ(output,t)k=∂Et∂nettyk
所以,我们有:
∂E2∂h2i=∑k′∂E2∂net2yk′∂net2yk′∂h2i=∑k′δ(output,2)k′∂net2yk′∂h2i=∑k′δ(output,2)k′wk′i
综上:
δ(2,2)i=∂E2∂net2hi=∂E2∂h2i∂h2i∂net2hi=(∑k′δ(output,2)k′wk′i)∗f′(net2hi)
对于 t=1 , t′=2 时:
( E2 表示第2个时刻的误差)
( net1hi 表示第1个时刻隐藏层第i个神经元的净输入)
具体来说: ∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi
那么 ∂E2∂h1i=∑k′∂E2∂net1yk′∂net1yk′∂h1i+∑j∂E2∂net2hj∂net2hj∂h1i 。请对比这个式子和上面 t=2 , t′=2 时的区别,区别在于多了一项 ∑j∂E2∂net2hj∂net2hj∂h1i 。这个原因我们已经在RNN与bp算法中讨论过,这里简单的说就是由于 t=1 时刻有 t=2 时刻反向传播回来的误差,所以要考虑上这一项,但是对于 t=2 已经是最后一个时刻了,没有反向传播回来的误差。
对于第一项 ∑k′∂E2∂net1yk′∂net1yk′∂h1i 其实是0。下面简单分析下原因:
上式进一步可以化为: ∑k′(∑k″∂E2∂o1k″∂o1k″∂net1yk′)∂net1yk′∂h1i 而 E2 与第1个时刻输出 o1k″ 无关。所以为0。
对于第二项 ∑j∂E2∂net2hj∂net2hj∂h1i ,我们带入 δ(t′,t)i=∂Et∂nett′hi 有:
∑j∂E2∂net2hj∂net2hj∂h1i=∑jδ(2,2)j∂net2hj∂h1i 。
同时明显有 ∂net2hj∂h1i=uji
即: ∂E2∂h1i=∑jδ(2,2)juji
综上:
δ(1,2)i=∂E2∂net1hi=∂E2∂h1i∂h1i∂net1hi=(∑jδ(2,2)j∂net2hj∂h1i)∗f′(net1hi)=(∑jδ(2,2)juji)∗f′(net1hi)
对于 t=0 , t′=2 时:
( E2 表示第2个时刻的误差)
( net0hi 表示第0个时刻隐藏层第i个神经元的净输入)。
和上面的思路一样,我们容易得到:
δ(0,2)i=∂E2∂net0hi=(∑jδ(1,2)juji)∗f′(net0hi) 。
至此,我们求完了 ∂Et∂nett′hi 。下面我们来总结一下其通式:
(∑k′δ(output,t)k′wk′i)∗f′(nett′hi),(∑jδ(t′+1,t)juji)∗f′(nett′hi),t=t′t≠t′
另外,对于 δ(output,t)k 有以下表达式:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′∂otk′∂nettyk=(otk−ztk)
最后只要求出 ∂nett′hi∂vim ,其值具体为 ∂nett′hi∂vim=xtm
最后,对于 ∂Et∂uim 其实和上面的差不多,主要是后面的部分不一样,具体来说:
∂Et∂uim=∑tt′=0∂Et∂nett′hi∂nett′hi∂uim ,可以看到就只有等式右边的第二项不一样,关键部分是一样的。 ∂nett′hi∂uim=ht′−1m
细节-1
上面提到,当只有3个时刻时,t=0,1,2。
对于误差 E2 (最后一个时刻的误差),没有再下一个时刻反向传回的误差。
那么对于 E1 (第1个时刻的误差)存在下一个时刻反向传回的误差,但是在 ∂E1∂h1i 中的第二项 ∑j∂E1∂net2hj∂net2hj∂h1i 仍然为0。是因为 ∂E1∂net2hj=0 ,因为 E1 的误差和下一个时刻隐藏层的输出没有任何关系。
总结
看起来bptt和我们之前讨论的bp本质上是一样的,只是在一些细节的处理上由于权重共享的原因有所不同,但是基本上还是一样的。
下面这篇文章是有一个简单的rnn代码,大家可以参考一下
参考文章1
代码的bptt中每一步的迭代公式其实就是上面的公式。希望对大家有帮助~
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152320.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...