【笔记】MAML-模型无关元学习算法「建议收藏」

【笔记】MAML-模型无关元学习算法「建议收藏」[TOC]论文信息:FinnC,AbbeelP,LevineS.Modelagnosticmetalearningforfastadaptationofdeepnet

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

论文信息:

Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.

一、摘要

  • 学习的目标是在各种学习任务上训练一个模型,这样它就可以使用少量的训练样本来解决新的学习任务。

  • 本文提出了一种与模型无关的元学习算法,它适用于任何基于梯度下降进行训练的模型,并且适用于各种学习问题,如分类(Classification)、回归(Regression)和强化学习(Reinforcement Learning)

  • 在本文提出的方法中,模型的参数被显式地训练,模型在处理新任务时,只需几次的梯度更新以及少量的训练数据就能取得良好的泛化性能。

  • 该方法在两种few-shot图像分类基准(Omniglot和 MiniImagenet)上取得了较好的性能,在few-shot回归上取得了较好的效果,并利用神经网络策略加速了策略梯度强化学习的微调。

二、背景

  • 显式训练与隐式训练

    参考显函数与隐函数的概念:

    • 隐函数:能确定y与x之间关系的方程,F(x,y)=0。x与y混杂在一起。有些隐函数可显化为显函数。
    • 显函数:用y=f(x)表示的函数。x与y明显区分。
    • 函数是方程,方程不一定是函数。因为函数需要实现一个数域到另一个数域的映射,而方程只要是含有未知数的等式即可。

    这样模型参数的显式训练与隐式训练就可以理解为因果区分与因果混杂的情况。

    • 隐式训练:没有明确的表达式来对目标参数进行更新。
    • 显式训练:存在明确的表达式来更新目标参数。
  • 参数方法与非参数方法

    • 参数方法(parametric method):根据先验知识假定模型服从某种分布,然后利用训练集估计出模型参数。这种方法中模型的参数固定,不随数据点的变化而变化。
    • 非参数方法(parametric method):基于记忆训练集,在预测新样本值时每次都会重新训练数据,得到新的参数值。参数的数目随数据点的变化而变化。
  • Hessian Matrix(海森矩阵)

    • 海塞矩阵(Hessian Matrix),又译作海森矩阵,是一个多元函数的二阶偏导数构成的方阵。

    • 处理一元函数极值问题,如\(f(x)=x^2\) ,我们会先求一阶导数,即 \(f^{\prime}(x)=2x\) ,然后根据费马定理——极值点处的一阶导数一定等于 0。但这仅是一个必要条件,而非充分条件。如 \(f(x)=x^3\),显然只检查一阶导数是不足以下定论的。所以进行二次求导,得出以下规律:

      • 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)<0\) ,则\(f(x)\) 在此点处取得局部极大值;
      • 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)>0\) ,则\(f(x)\) 在此点处取得局部极小值;
      • 如果一阶导数\(f^{\prime}(x)=0\) 且二阶导数\(f^{\prime \prime}(x_0)=0\) ,则无法确定
    • 处理多元函数极值问题,则需要首先对每个变量求偏导,令其为零,定位极值点的可能位置,然后利用二阶导数判断是极大值还是极小值。\(n\) 元函数有 \(n^2\) 个二阶导数,因此构成海森矩阵

      \[ \mathbf{H}=\begin{bmatrix} \frac{\partial^2f}{\partial x_1^2} & \frac{\partial^2f}{\partial x_1\partial x_2} & \cdots &\frac{\partial^2f}{\partial x_1\partial x_n} \\ \frac{\partial^2f}{\partial x_2\partial x_1} & \frac{\partial^2f}{\partial x_2^2} & \cdots &\frac{\partial^2f}{\partial x_2\partial x_n} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial^2f}{\partial x_n\partial x_1}&\frac{\partial^2f}{\partial x_n\partial x_2}&\cdots &\frac{\partial^2f}{\partial x_n^2} \end{bmatrix} \]

      • 海森矩阵的极值判断阶段如下:
        • 如果是正定矩阵,则临界点处是一个局部极小值
        • 如果是负定矩阵,则临界点处是一个局部极大值
        • 如果是不定矩阵,则临界点处不是极值
  • 元学习问题引入

    • 元学习过程实际上是一个创造一个高级代理的过程,这个代理在处理新任务的新数据时,能将先验知识整合进来并且能避免过拟合,即在不同任务之间具备泛化能力。

      高级代理可以理解为创造模型的模型,或者是一组模型参数,它能够根据不同的任务生成不同的模型参数,这套模型参数能够在新任务给定的新数据上快速的学习,适应任务的需要。

    • 为了得到具有快速适应能力的模型,元学习训练一般以Few-Shot Learning(少样本学习)的形式进行。

      • Few-Shot,可以分为1~k shot,即在训练过程中提供给模型1~k个样本数据,让模型进行学习。

      • 注意与Small Sample Learning(SSL,小样本学习)进行区分。后者的范围比前者更加广泛,具体参见Small Sample Learning in Big Data Era

    • 通过少量的样本数据构建成一个任务,然后让元学习模型在许多依此法创建的任务上进行训练学习,这样,经过训练的元学习模型就能凭借少量的数据和几次的训练快速适应新的任务了。

      实际上,元学习模型的训练过程是以一整个一整个的任务作为”训练数据样本“的。

  • 元学习问题的公式化表达

    • 概念定义

      • 定义一个模型,用\(f\)表示。模型\(f\)能实现观察值\(x\)到输出值\(a\)的映射。

      • 定义单个任务\(T\)

        \[\mathcal{T= \left\{ L(\mathrm{x_1,a_1,\dots,x_H,a_H}),q(\mathrm{x_1}),q(\mathrm{x_{t+1}|x_t,a_t}),\mathrm{H}\right\}} \]

        • \(\mathcal{L}\)表示损失函数,\(\mathcal{L(\mathrm{x_1,a_1,\dots,x_H,a_H})}\rightarrow \mathbb{R}\)
        • \(\mathcal{q}(\mathrm{x_1})\)表示初始观测变量的分布。
        • \(\mathcal{q(\mathrm{x_{t+1}|x_t,a_t})}\)表示转移分布
        • \(\mathrm{H}\)表示跨度(Episode Length),对于i.i.d(独立同分布)监督学习问题,H=1。
      • 期望模型适应的任务的分布\(p(\mathcal{T})\)

    • 学习过程

      • 初始化:随机初始化元学习模型参数\(\theta\),各子任务模型的初始化参数是对\(\theta\)的拷贝。

      • 元训练

        1. \(p(\mathcal{T})\)中抽取任务\(\mathcal{T_i}\)
        2. \(\mathcal{q(i)}\)中抽取\(\mathrm{K}\)个样本;
        3. 用这\(\mathrm{K}\)个样本对任务\(\mathcal{T_i}\)进行训练,得到相应的损失\(\mathcal{L_{T_i}}\),并对该任务的模型参数进行梯度更新;
        4. 在新的数据样本上测试更新后的网络,得到错误情况。
      • 元测试

        1. 根据各个任务更新后的网络的表现(test error)求初始化参数的梯度,并对元学习模型的参数其进行更新;
        2. 测试其在元测试集任务上的表现,即为元学习模型的最终表现。

三、介绍

  • 本文提出的MAML算法的关键思想:训练模型的初始化参数,使模型能在来自新任务的少量数据上对参数执行数次(1~多次)的梯度更新后能得到最佳的表现。

    • 特征学习的角度理解——MAML算法试图建立一种模型的内部表示,这种内部表示广泛适用于许多任务。这样在处理新的任务时,只需对模型参数进行简单的微调就能产生较好的结果。

    • 动态系统的角度理解——MAML的学习过程就是要让新任务的损失函数对于参数的敏感度最大化。当具有较高的敏感度时,参数的微小的局部变化就可以导致任务损失的巨大提升。

      动态系统:若系统在t0时刻的响应y(t0),不仅与t0时刻作用于系统的激励有关,而且与区间(-∞,t0)内作用于系统的激励有关,这样的系统称为动态系统。

  • 本文的主要贡献包括以下几个方面:

    1. 提出了一种元学习的简单模型以及与任务无关的算法,通过训练模型参数,使得模型参数只要经过少量次数的梯度更新就能实现在新任务上的快速学习。
    2. 在不同的模型,如全连接和卷积网络,以及不同领域上,如少样本回归、图片分类和增强学习上验证了本文提出的算法。
    3. 本文提出的方法通过使用少量参数,能够与目前最先进的专门用于监督分类的one-shot 学习算法媲美,并且能够应用于回归任务和加速任务可变情况下的强化学习过程。

四、实现

  • MAML算法的实现直觉(Intuition)是模型的某些内部表示更容易在不同的任务之间转换。比如存在某种内部表示能够适用于任务分布\(\mathcal{p(T)}\)中的所有任务而不是某一个具体的任务。由于最终模型会在新任务上使用基于梯度下降的学习规则进行微调,所以可以以一种显式的方式去学习一个具备这种规则的模型。

    这种待学习的规则可以理解为一组对任务变化敏感的模型参数,当参数沿着任务的损失梯度方向变化时,可以使得任务损失得到较大的改善。

  • 原理图如下:

    MAML原理图

    • \(\theta\) 是已经优化过的模型参数表示。
    • \(\theta\) 沿着新任务损失梯度方向变化时,会使得任务损失大幅改善,从而得到对于新任务的最佳模型参数 \(\theta^{\star}\)
  • 算法描述:

    MAML算法

    • 模型由函数 \(f_{\theta}\) 表示,该函数由参数 \(\theta\) 决定。

    • 整个算法分为两个循环:

      • 两者共享模型参数 \(\theta\)
      • 两者的梯度更新的学习率分别由超参数 \(\alpha\)\(\beta\) 表示
      • 内循环计算各子任务的损失 \(\mathcal{L_{T_{i}}}\) 和进行一至多次梯度更新后的参数 \(\theta^{‘}_{i}\)
      • 外循环根据内循环的优化参数在新任务上重新计算损失,并计算其对初始参数的梯度,然后对初始参数进行SGD梯度更新。
      • 重复内外循环,就可以得到元学习模型对于任务分布$ \mathcal{p(T)}$的最佳参数
    • 注意

      • 拥有“最佳参数”的模型在处理新任务时,由于具备了先验知识,所以只需进行微调就能产生较好的效果。
      • 外循环又称之为元优化(meta-optimization)
      • 为了适应不同的任务,内循环中的模型参数会演化成 \(\theta^{\prime}\)。而外循环中模型参数需要等到内循环中的所有任务的模型参数都优化后再进行更新。
      • 由于存在一个嵌套关系,外层的梯度更新依赖内层的梯度,因此就会出现二阶导数(梯度的梯度)的计算,需要使用到海森向量积(Hessian-Vector Product)
      • 在论文中,作者提出了一种近似算法,利用一阶梯度近似代替二阶梯度,形成FOMAML(First-Order MAML)算法,具体公式推导过程,见MAML讲解-李弘毅
  • 算法扩展:

    • 监督学习(Supervised Learning):算法中的公式(2)和公式(3)分别指代下面的两个损失函数。

      【笔记】MAML-模型无关元学习算法「建议收藏」

      • 分类(Classification)任务的损失函数采用交叉熵(cross entropy)

        \[\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}y^{(j)}\log f_{\phi}(x^{(j)})+(1-y^{(j)})\log(1-f_{\phi}(x^{(j)})) \]

      • 回归(Regression)任务的损失函数采用均方差(mean-squared error)

        \[\mathcal{L_{T_i}(f_\phi)}=\sum_{x^{(j)},y^{(j)}\sim \mathcal{T_i}}\begin{Vmatrix} f_{\phi}(x^{(j)})-y^{(j)}\end{Vmatrix}_2^2 \]

    • 强化学习(Reinforcement Learning):算法中的公式(4)指代下面的损失函数。

      【笔记】MAML-模型无关元学习算法「建议收藏」

      • 强化学习损失函数

        • 强化学习过程基于马尔可夫决策过程(Markov Decision Porcess)。
        • 具体细节还未深入了解,待补充……
        \[\mathcal{L_{T_i}(f_\phi)}=-\mathbb{E}_\mathcal{x_t,a_t\sim f_\phi,q_{T_i}}[\sum_{t=1}^H R_i(x_t,a_t)] \]

五、实验

实验代码:

  • 回归(正弦曲线)

    • 通过将MAML算法模型与预训练模型比较,分别提供K=5和K=10个样本数据,进行回归拟合。可以看到:【笔记】MAML-模型无关元学习算法「建议收藏」

      • 在没有提供任何数据点的情况下,MAML由于已经学习到了正弦波的周期结构,所以能够对曲线进行一定的评估;
      • 对于预训练模型,由于输出与已学习到的先验知识冲突,导致模型无法找到一个合适的表示形式,从而无法通过少量的样本进行拟合推断。
    • 比较MAML和预训练模型的学习曲线可以得出:

      【笔记】MAML-模型无关元学习算法「建议收藏」

      • MAML算法通过少量次数的梯度更新就能实现较低的错误率,没有对少量的数据点过拟合,达到收敛。
      • 预训练模型则由于缺乏泛化能力,对与少量数据点,很容易过拟合。
  • 分类

    • 通过将MAML模型以及简化后的FOMAML模型与用于Few-Shot Learning 分类的主流模型在Omiglot和MiniImagenet数据集上比较,可以发现:

      【笔记】MAML-模型无关元学习算法「建议收藏」

      • MAML无视数据集差异、数据点多少以及网络结构差异,都有优异的表现。

      • FOMAML模型的表现与MAML的表现非常接近,但是两者的计算消耗却不同,FOMAML的计算复杂度要明显低于MAML,这一点也是值得进一步研究的问题。

        对此,作者推测在大多数情况下,损失函数的二阶导数非常接近零,因而对模型表现没有产生太大的影响。

        On First-Order Meta-Learning Algorithms一文中,作者用泰勒公式,对导数进行了展开分析,揭露了深层次的原因。

  • 强化学习

    【笔记】MAML-模型无关元学习算法「建议收藏」

六、总结

  • 提出了一种不引入任何学习参数(实际上增加了学习率\(\alpha 和 \beta\))的通过梯度下降学习模型参数的元学习方法。
  • MAML可以以与任何适合于基于梯度的训练的模型表示,以及任何可微分的目标(包括分类、回归和强化学习)相结合。
  • MAML只产生一个权值初始化,所以可以使用任意数量的数据和任意数量的梯度步长来执行自适应。
  • MAML可以使用策略梯度和非常有限的经验来适应RL代理。
  • 重用来自过去任务的知识可能是构建高容量可伸缩模型(如深度神经网络)的一个关键因素,该模型能够使用小数据集进行快速训练。
  • 这种元学习技术可以应用于任何问题和任何模型,可以使多任务初始化成为深度学习和强化学习的标准组成部分。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/167621.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • java string.split()用字符串分割_java 字符串分割

    java string.split()用字符串分割_java 字符串分割最近写代码时遇到自字符串分割和截取的问题,在此总结一下。字符串的分割:一般自字符串的分割常用的方法是java.lang包中的String.split()方法,返回是一个字符串数组。语法:publicString[]split(Stringregex,intlimit)参数: regex–正则表达式分隔符。 limit–分割的份数。…

  • windows下面编译ucosII操作系统

    windows下面编译ucosII操作系统       ucos是一款在嵌入式系统上应用的实时操作系统,为了调试和学习(我们部门负责DSP、MCU、ARM到服务器的各种程序),有必要再windows下面模拟运行,我在一个德国网站上找到了一份移植过的代码,经过我的小小修改,已经可以用VS2010和Dev-C++(MinGw编译器)上编译运行。 运行过程中发现2个编译器编译出来的程序运行结果并不相同,看来2种编译器在实现…

  • GPIB-VC编程

    GPIB-VC编程CompilingandLinkingVISAPrograms(C/C++)Thissectionprovidesasummaryofimportantcompiler-specificconsiderationswhendevelopingWin32applications.1.LinkingtoVISALibrariesYourapp

  • 【数字图像处理】C++读取、旋转和保存bmp图像文件编程实现

    【数字图像处理】C++读取、旋转和保存bmp图像文件编程实现通过我这些天用C++读写bmp图像的经历,摸索再摸索,终于对bmp文件的结构、操作有了一定的了解,下面就大概介绍bmp图片纯C++的读取、旋转和保存的实现过程。要用C++读取bmp图片文件,首先要弄清楚bmp格式图片文件的结构。可以参考这篇文章:http://blog.csdn.net/xiajun07061225/article/details/5813726有几点需要注意的是:在读

  • BP神经网络预测模型(神经网络算法模型)

    学习率一般在(0,0.1)区间上取值.隐含层节点数量(√为开根号):①m=(√(i+j))+α②m=log2(i)③m=√(i*j)m:隐含层节点i:输入层节点数j:输出层节点数α:1-10之间的常数

  • 整理:FPGA选型[通俗易懂]

    整理:FPGA选型[通俗易懂]针对性整理下FPGA选型问题一、获取芯片资料:要做芯片的选型,首先就是要对有可能要面对的芯片有整体的了解,也就是说要尽可能多的先获取芯片的资料。现在FPGA主要有4个生产厂家,ALTERA,XIL

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号