DeepLearning之LSTM模型输入参数:time_step, input_size, batch_size的理解[通俗易懂]

DeepLearning之LSTM模型输入参数:time_step, input_size, batch_size的理解[通俗易懂]1.LSTM模型输入参数理解(LongShort-TermMemory)lstm是RNN模型的一种变种模式,增加了输入门,遗忘门,输出门。LSTM也是在时间序列预测中的常用模型。小白我也是从这个模型入门来开始机器学习的坑。LSTM的基本概念与各个门的解释已经有博文写的非常详细:推荐博文:【译】理解LSTM(通俗易懂版)这篇文章写的非常详细,生动,概念解释的非常清楚。我也是从这个…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

1. LSTM模型 输入参数理解

(Long Short-Term Memory)

lstm是RNN模型的一种变种模式,增加了输入门,遗忘门,输出门。

LSTM也是在时间序列预测中的常用模型。

小白我也是从这个模型入门来开始机器学习的坑。

LSTM的基本概念与各个门的解释已经有博文写的非常详细:推荐博文:【译】理解LSTM(通俗易懂版)

这篇文章写的非常详细,生动,概念解释的非常清楚。我也是从这个博文里开始理解的。


2. 模型参数

  1. 模型的调参是模型训练中非常重要的一部分,调整参数前的重要一步就是要理解参数是什么意思,才能帮助更好的调整参数。
  2. 但是发现在一些实战模型将代码直接放在那里,但是基本参数只是把定义写在哪里,没有生动的解释,我一开始看的时候也是一脸懵逼。
  3. 在我寻找着写参数的额定义的时候,往往看不到让小白一眼就能明白的解释。
  4. 希望从一个小白的角度来讲解我眼中的这些参数是什么意思,如果有不对,还请指出交流。

3. LSTM 的参数输入格式

1. 一般表示为[batch_size, time_step, input_size]

2. 中文解释为[每一次feed数据的行数,时间步长,输入变量个数]


3.1 分开讲解,input_size

  1. 如果你使用7个自变量来预测1个因变量,那么input_size=7,output_size=1
  2. 如果你使用8个自变量来预测3个因变量,那么input_size=8,output_size=3

这个还是比较好理解的,你的输入数据,想要通过什么变量预测什么变量应该是比较清楚的。

难点是另外两个参数的区别。


3.2 分开讲解,batch_size

  1. 如果你的数据有10000行,训练100次把所有数据训练完,那么你的batch_size=10000/100=100
  2. 如果你的数据有20000行,同样训练100次把所有数据训练完,那么你的batch_size=20000/100=200
  3. 如果你的数据有20000行,训练50次把所有数据训练完,那么你的batch_size=20000/50=400
  4. 以此类推
  5. 不过只是举个例子,实际的情况要看你的数据样本,一般的batch_size小于100,来使你的训练结果更好,一次feed太多行数据,模型容易吃撑,消化不良,可能需要健胃消食片,哈哈哈哈

3.3 分开讲解, time_step

最最最最难理解的就是这个time_step了,我也是琢磨了好久。

  1. 首先要知道,time_step是指的哪个过程?
    是不是看到的图都是在画,输入了什么,遗忘了什么,输出了什么,以为每个细胞状态都是1个time_step?
    如果这样的话,那么恭喜你,你和我一样,都是想错了,其实那些一串的流程细胞状态图都是在1个time_step!都是在1个time_step!都是在1个time_step!
  2. 是不是很惊讶,很奇怪?
  3. 那讲的是time_step的内部进行的,而不是在time_step之间。
  4. 换句话说,所谓的t-1的遗留状态也是在一个time_step里面的事情,t多少取决于time_step的取值。

此时,再来看看time_step的本身含义,时间步长,时间步长,那么一定是是和时间有关系啊!!!

4. 重点

4.1 batch_size与time_step

  1. 之前的batch_size中只是规定了一个每次feed多少行数据进去,并没有涵盖一个时间的概念进去,
  2. 而这个参数刚好就是对于时间的限制,毕竟你是做时间序列预测,所以才多了这个参数。
  3. 换句话说,就是在一个batch_size中,你要定义一下每次数据的时间序列是多少?
  4. 如果你的数据都是按照时间排列的,batch_size是100的话,time_step=10
  5. 在第1次训练的时候,是用前100行数据进行训练,而在这其中每次给模型10个连续时间序列的数据。
  6. 那你是不是以为应该是1-10,11-20,21-30,这样把数据给模型?还是不对,请看下图。

4.2 [batch_size, time_step, input_size]=[30,5,7]

time_step=n, 就意味着我们认为每一个值都和它前n个值有关系
在这里插入图片描述

  1. 如果 [batch_size, time_step, input_size]=[30,5,7]
  2. 那么,上图中,黑色框代表的就是一个batch_size中所含有的数据的量。
  3. 那么,从上到下的3个红色框就为 time_step为5的时候,每次细胞输入门所输入的数据量。
  4. 那么,列B~列H,一共7列,就为 input_size

4.3 举例

再看下图
在这里插入图片描述

time_step=n, 就意味着我们认为每一个值都和它前n个值有关系

  1. 假如没有time_step这个参数, [input_size=7,batch_size=30],一共只需要1次就能训练完所有数据。
  2. 如果有,那么变成了 [input_size=7,batch_size=30, time_step=5],需要30-5+1=26,需要26次数据连续喂给模型,中间不能停。
  3. 在26次中每一次都要把上一次产生的y,与这一次的5行连续时间序列数据一起feed进去,再产生新的y
  4. 以此往复,直到此个batch_size 结束。

结语

1. input_size 是根据你的训练问题而确定的。

2. time_step是LSTM神经网络中的重要参数,time_step在神经网络模型建好后一般就不会改变了。

3. 与time_step不同的是,batch_size是模型训练时的训练参数,在模型训练时可根据模型训练的结果以及loss随时进行调整,达到最优。


非常感谢以下作者,让我慢慢理解了参数意义,才有了以上学习笔记!

参考资料:

菜鸡的自我拯救,RNN 参数理解

视觉弘毅,RNN之多层LSTM理解

MichaelLiu_dev,理解LSTM(通俗易懂版)

Andrej Karpathy,The Unreasonable Effectiveness of Recurrent Neural Networks

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/197633.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 使用 parted 对单个磁盘进行分区并进行配额「建议收藏」

    使用 parted 对单个磁盘进行分区并进行配额「建议收藏」文章目录1.实验要求2.实验步骤3.我的一次实验步骤1.实验要求虚拟机新增一个硬盘,大小大于10G,使用parted工具对磁盘进行分区,分区类型为ext4对新增分区设置磁盘配额,限制lisi用户最多允许使用200M的容量大小并最多允许创建10个文件2.实验步骤准备一个新虚拟机,我们用新环境进行实验VMware添加一块20G硬盘echo”—“>/sys/class/scsi_host/host0/scan#扫描主机fdis

  • phpstudy中的mysql

    phpstudy中的mysql

    2021年10月14日
  • 一阶倒立摆分析_倒立摆受力分析

    一阶倒立摆分析_倒立摆受力分析摆的运动是两种运动的叠加:1.平动,包含x方向和y方向。2.转动,转轴为质心。尽管物理上的转轴是其端点,但这个端点同时也是摆的受力点。在端点(非中心)施加垂直于摆臂的力,摆将绕其质心转动。  因为摆的重力作用于其转轴(质心),因此摆自身的重力对摆不施加力矩。这可以算作将质心作为转轴来分析的一个优势。   …

  • 51单片机8×8点阵屏设计(51单片机led光立方)

    1.简介本设计是以STC89C52单片机的8x8x8的LED光立方。本设计将LED光立方分成8层,分别由单片机的P1,8个IO口来控制每一层,由于采用的是共阴极所以当层电位为高电平有效,由P0口和P2的总共16个IO口来控制每层的64盏灯,低电平有效,P2口通过8个74HC573缓冲器芯片来驱动LED。这样就可以通过控制IO口的输出电平来控制每盏灯的亮灭。2.硬件设计本系统的硬件电路主要单片…

  • 5G信道建模

    5G信道建模5G毫米波一般认为毫米波波段的信道具有稀疏性,即径数远小于天线数,因此直接在角度域上通过估计各条径的AoD/AoA和增益系数做信道估计,比起在天线域上做信道估计更简单。但这么做还隐含了每条可分辨径的角度扩展很小这样的假设,在mmWavemMIMO系统中,信道估计等同于估计AoA和AoD以及每条path的散射系数,就是毫米波波段的情况。而在低频NLoS情况下,由于散射传播路径非常丰富,信道不存在稀疏性,也没有一个清晰的几何结构,因此一般建模为随机的比如Rayleigh信道。对于这种信道我们就对其整体进行估

  • mariadb 的安装及基本配置

    mariadb 的安装及基本配置文章目录一、mariadb介绍二、mariadb下载及安装三、mariadb的启停命令四、mariadb的配置五、添加用户,设置权限Navicat连接数据库一、mariadb介绍MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可。开发这个分支的原因之一是:甲骨文公司收购了MySQL后,有将MySQL闭源的潜在风险,因此社区采用分支的方式来避开这个风险。MariaDB的目的是完全兼容MySQL,包括API和命令行,使之能轻松成为MySQL的代替品。在存储引擎

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号