BPTT

BPTTRNN的BP——BackPropagationThroughTime.参考:零基础入门深度学习(5)-循环神经网络。知乎。1   defbackward(self,sensitivity_array,2activator):3”’4实现BPTT…

大家好,又见面了,我是你们的朋友全栈君。

RNN 的 BP —— Back Propagation Through Time.

参考:零基础入门深度学习(5) – 循环神经网络知乎

 

BPTT

BPTT

BPTT

BPTT

BPTT

 1   def backward(self, sensitivity_array, 
 2                  activator):
 3         '''
 4         实现BPTT算法
 5         '''
 6         self.calc_delta(sensitivity_array, activator)
 7         self.calc_gradient()
 8     def calc_delta(self, sensitivity_array, activator):
 9         self.delta_list = []  # 用来保存各个时刻的误差项
10         for i in range(self.times):
11             self.delta_list.append(np.zeros(
12                 (self.state_width, 1)))
13         self.delta_list.append(sensitivity_array)
14         # 迭代计算每个时刻的误差项
15         for k in range(self.times - 1, 0, -1):
16             self.calc_delta_k(k, activator)
17     def calc_delta_k(self, k, activator):
18         '''
19         根据k+1时刻的delta计算k时刻的delta
20         '''
21         state = self.state_list[k+1].copy()
22         element_wise_op(self.state_list[k+1],
23                     activator.backward)
24         self.delta_list[k] = np.dot(
25             np.dot(self.delta_list[k+1].T, self.W),
26             np.diag(state[:,0])).T
27     def calc_gradient(self):
28         self.gradient_list = [] # 保存各个时刻的权重梯度
29         for t in range(self.times + 1):
30             self.gradient_list.append(np.zeros(
31                 (self.state_width, self.state_width)))
32         for t in range(self.times, 0, -1):
33             self.calc_gradient_t(t)
34         # 实际的梯度是各个时刻梯度之和
35         self.gradient = reduce(
36             lambda a, b: a + b, self.gradient_list,
37             self.gradient_list[0]) # [0]被初始化为0且没有被修改过
38     def calc_gradient_t(self, t):
39         '''
40         计算每个时刻t权重的梯度
41         '''
42         gradient = np.dot(self.delta_list[t],
43             self.state_list[t-1].T)
44         self.gradient_list[t] = gradient

 

 1 class RNN2(RNN1):
 2     # 定义 Sigmoid 激活函数
 3     def activate(self, x):
 4         return 1 / (1 + np.exp(-x))
 5 
 6     # 定义 Softmax 变换函数
 7     def transform(self, x):
 8         safe_exp = np.exp(x - np.max(x))
 9         return safe_exp / np.sum(safe_exp)
10 
11     def bptt(self, x, y):
12         x, y, n = np.asarray(x), np.asarray(y), len(y)
13         # 获得各个输出,同时计算好各个 State
14         o = self.run(x)
15         # 照着公式敲即可 ( σ'ω')σ
16         dis = o - y
17         dv = dis.T.dot(self._states[:-1])
18         du = np.zeros_like(self._u)
19         dw = np.zeros_like(self._w)
20         for t in range(n-1, -1, -1):
21             st = self._states[t]
22             ds = self._v.T.dot(dis[t]) * st * (1 - st)
23             # 这里额外设定了最多往回看 10 步
24             for bptt_step in range(t, max(-1, t-10), -1):
25                 du += np.outer(ds, x[bptt_step])
26                 dw += np.outer(ds, self._states[bptt_step-1])
27                 st = self._states[bptt_step-1]
28                 ds = self._w.T.dot(ds) * st * (1 - st)
29         return du, dv, dw
30 
31     def loss(self, x, y):
32         o = self.run(x)
33         return np.sum(
34             -y * np.log(np.maximum(o, 1e-12)) -
35             (1 - y) * np.log(np.maximum(1 - o, 1e-12))
36         )

 

转载于:https://www.cnblogs.com/niuxichuan/p/8094800.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152280.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 查看win7系统激活信息时候常用的一些命令

    查看win7系统激活信息时候常用的一些命令1.slmgr.vbs-dli  显示:操作系统版本、部门产品密钥、许可证状态  2.slmgr.vbs-dlv  显示:最为详尽的激活信息,包括:激活ID、安装ID、激活截止日期  3.slmgr.vbs-xpr  显示:是不是彻底激活  4.slmgr.vbs-ipk  更换WIN7序列号  5.slmgr.vbs-ato  激活WIN7 …

  • 如果要将二叉树{16,14,10,8,7,9,3}_二叉分枝

    如果要将二叉树{16,14,10,8,7,9,3}_二叉分枝有一棵二叉苹果树,如果树枝有分叉,一定是分两叉,即没有只有一个儿子的节点。这棵树共 N 个节点,编号为 1 至 N,树根编号一定为 1。我们用一根树枝两端连接的节点编号描述一根树枝的位置。一棵苹果树的树枝太多了,需要剪枝。但是一些树枝上长有苹果,给定需要保留的树枝数量,求最多能留住多少苹果。这里的保留是指最终与1号点连通。输入格式第一行包含两个整数 N 和 Q,分别表示树的节点数以及要保留的树枝数量。接下来 N−1 行描述树枝信息,每行三个整数,前两个是它连接的节点的编号,第三个数是这根树枝上

  • python基础(3)列表list[通俗易懂]

    python基础(3)列表list[通俗易懂]列表列表特点:是一种序列结构,与元组不同,列表具有可变性,可以追加、插入、删除、替换列表中的元素新增元素appendappend添加一个对象,可以是任意类型a=['zhangsa

  • pascal voc数据集下载_目标检测分类

    pascal voc数据集下载_目标检测分类一、简介PASCALVOC挑战赛主要有ObjectClassification、ObjectDetection、ObjectSegmentation、HumanLayout、ActionClassification这几类子任务PASCAL主页与排行榜PASCALVOC2007挑战赛主页、PASCALVOC2012挑战赛主页、PASC…

  • Python(含PyCharm及配置)下载安装以及简单使用(Idea)「建议收藏」

    Python(含PyCharm及配置)下载安装以及简单使用(Idea)「建议收藏」下载Python官网下载地址:Python下载不同参数解释,小伙伴们根据自己情况进行下载即可(此处博主用的是3.7.3版本):–web-basedinstaller:在线安装。下载的是一个exe可执行程序,双击后,该程序自动下载安装文件进行安装。网络安装版,需联网–executableinstaller:程序安装。下载的是一个exe可执行程序,双击进行安装。本地安装,可执行程序(***)–embeddablezipfile:解压安装。下载的是一个压缩文件,解压后即表示安装完成。嵌入式版

  • java 二维数组 数据库_java 二维数组如何存入数据库

    java 二维数组 数据库_java 二维数组如何存入数据库usingSystem;usingSystem.Linq;usingSystem.Text;usingSystem.Windows.Forms;usingSystem.Xml;usingSystem.Xml.Serialization;usingSystem.IO;namespaceWindowsFormsApplication1{publicpartialclassForm…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号