从STN网络到deformable convolution

从STN网络到deformable convolution1  STN (SpatialTransformerNetwork)1.1    来源论文来源:https://arxiv.org/pdf/1506.02025.pdf    参考博客:    1. https://blog.csdn.net/ly244855983/article/details/80033788(论文解读)    2. https://blog.csdn.net/xbi…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

1    STN (Spatial Transformer Network)

1.1    来源

论文来源:https://arxiv.org/pdf/1506.02025.pdf

    从STN网络到deformable convolution

参考博客:

    1. https://blog.csdn.net/ly244855983/article/details/80033788(论文解读)

    2. https://blog.csdn.net/xbinworld/article/details/69049680 (梯度流动图)

    3. https://blog.csdn.net/l691899397/article/details/53641485(这里有caffe代码分析)

1.2    动机

普通CNN能够显式学习到平移不变性,隐式学习到旋转不变性、伸缩/尺度不变性(通常学的不够好),但是attention机制的成功告诉我们,与其让网络自己隐式学习某个能力,不如为它显式设计某个模块,让它更容易学习到这个能力。

因此,设计STN的目的就是为了显式地赋予网络以上各项变换(transformation)的不变性(invariance)。

1.3    网络结构

        
从STN网络到deformable convolution

如上图所示,STN由Localisationnet(定位网络),Grid generator(网格生成器)和Sampler(采样器)三部分构成。

1. Localisation Net

Localisation Net 的目标是学习空间变换参数θ,无论通过全连接层还是卷积层,LocalisationNet 最后一层必须回归产生空间变换参数θ。

    输入:特征图U ,其大小为 (H, W, C)

    输出:空间变换参数θ(对于仿射变换来说,其大小为(6,))

    结构:结构任意,比如卷积、全连接均可,但最后一层必须是regression layer来产生参数θ,记作θ= floc(U)

2. Grid Generator

该层利用LocalisationNet 输出的空间变换参数θ,将输入的特征图进行变换,这个决定了变换前后图片U、V之间的坐标映射关系。

以仿射变换为例,将输出特征图上某一位置(xit,yit)通过参数θ映射到输入特征图上某一位置(xis,yis),上标t表示target,上标s表示source,计算公式如下:

            从STN网络到deformable convolution

因此,GridGenerator的作用是,输入target坐标,计算输出source坐标,因为STN的目标是从source中的不同坐标采集灰度值“贴”到target中,从而实现target的变换。

举个例子,经过仿射变换,对原图产生了平移和旋转,使得原本倾斜的图片变正了,如下图所示:

3. Sampler

Sampler根据GridGenerator产生的坐标映射关系,把输入图片U变换成输出图片V。

在计算中, (xis,yis)往往会落在原始输入特征图的几个像素点中间,因此需要利用双线性插值来计算出对应该点的灰度值:

            从STN网络到deformable convolution

    Unmc:是输入特征图U通道c中位置为 (n, m) 的灰度值。

    Vic :是输出特征图V通道c中位置为 (xit, yit),即像素点i的灰度值。

论文中在变换时用都是标准化坐标,即xi,yi∈[−1,1]。

(双线性插值参考文献:https://zh.wikipedia.org/wiki/%E5%8F%8C%E7%BA%BF%E6%80%A7%E6%8F%92%E5%80%BC)

1.4    反向传播

Localisationnet、Grid generator、Sampler三者都是可微的,因此它们可以插入到正常的网络构架中,通过反向传播更新参数,无需额外的监督信息。

这里可以看一下反向的梯度流动:

    1. Vic关于xis、yis的偏导,用于更新Localisationnet中的参数,注意这条梯度支流传递到Localisation net就结束了。

    2. Vic关于Unmc的偏导,这是整个网络梯度向前传递的流动。

    (缺图)

原论文中还给出了Sampler的反向求导公式,即Vic关于xis、yis的偏导、Vic关于Unmc的偏导,在此就不展开介绍。

1.5    优势

STN的主要特点:

1.      模块化:STN可以插入到现有深度学习网络结构的任意位置,且只需要较小的改动。

2.      可微分性:STN是一个可微分的结构,可以反向传播,整个网络可以端到端训练。

3.      不需要额外的监督信息。

1.6    论文中的实验结果

1.6.1    Distroted MINST

对MINST数据集做了rotation(R), rotation, scale and translation (RTS), projective transformation (P), andelastic warping (E)

    Baseline:FCN,CNN。

    ST-FCN:STN直接作用在inputiamge。

    ST-CNN:STN直接作用在inputiamge。

    STN内部:都使用双线性插值,但用了不同的transformationfunctions,θ参数不同:an affine transformation(Aff 仿射变换), projective transformation (Proj 投影变换),and a 16-point thin plate spline transformation (TPS 16点薄板样条变换)。

实验结果如下:

        从STN网络到deformable convolution

右边的图片是CNN识别失败但STN识别成功的例子。(a)输入图片(b)STN中grid的可视化 (c)STN输出图片

(更多的关于实验中网络结构的参数说明,请参见STN原论文)

1.6.2    Street View House Numbers

图片来自真实世界中的街景房屋编号,数据集约有200k张图片,每张图片中包含一个数字序列,数字范围是从1到5。

    Input data:从原图中crop得到包含数字编号的小图片,大小是64×64或128×128。

    Baseline:charactersequence CNN model,11个隐层,5个独立的softmaxclassifier。

    ST-CNN Single:在Baseline的inputdata上引入一个STN,其localisation net是一个4层CNN

    ST-CNN Multi:在Baseline的前4个卷积层之前,分别引入一个STN(一共4个STN),其localisationnet是2层32个神经元的FCN。

上述STN都使用了双线性差值和仿射变换。

实验结果如下:

        从STN网络到deformable convolution

(a)ST-CNN Multi的网络结构示意图 (b)ST-CNN Muli的可视化效果

1.6.3    CUB-200-2011 birds dataset

CUB-200-2011birds dataset来自《TheCaltech-UCSD Birds-200-2011 dataset》,包括6k训练图片,5.8k测试图片,包括200种鸟类,是多标签的。在该论文的实验中,只用了classlabel来训练。

    Baseline:anInception architecture with batch normalisation pre-trained on ImageNet andfine-tuned on CUB,达到82.3%的正确率。

    ST-CNN:2个或4个并行的STN,结构如下:

            从STN网络到deformable convolution

实验结果如下,4xST-CNN比Baseline提高了1.8%的准确率:

        (缺图)

右上图片:2xST-CNN的可视化结果,2个STN表现为红色、绿色的两个方框,有趣的是,看起来红框检测了鸟的头部,绿框检测了鸟的身体。

右下图片:4xST-CNN的可视化结果,有类似效果。

1.7    应用场景

1. 在OCR中,用STN对扭曲倾斜的文本区域整体进行校正,帮助后续网络更好地识别内容;STN也可以对单个字符做校正。

2. 在OCR中,从一整幅原图中找到若干个不同尺寸的文本区域,用STN得到固定尺寸的featuremap,再对其中的内容进行文字识别。

 

1.8    SpatialTransformerOP MXNet

SpatialTransformerOP是一个inference过程的算子,要求输入inputdata和locatisation net的参数θ。

1.8.1    接口

SpatialTransformer(data=None, loc=None,target_shape=_Null, transform_type=_Null, sampler_type=_Null, name=None,attr=None, out=None, **kwargs)

Parameters:

·         data (Symbol) – Input data to the SpatialTransformerOp.

·         loc (Symbol) – localisation net, the output dim should be 6 when transform_type is affine. You shold initialize the weight and bias with identity tranform.

·         target_shape (Shape(tuple), optional, default=[0,0]) – output shape(h, w) of spatialtransformer: (y, x)

·         transform_type ({‘affine’}, required) – transformation type

·         sampler_type ({‘bilinear’}, required) – sampling type

·         name (string, optional.) – Name of the resulting symbol.

Returns:

The result symbol.

Return type:

Symbol

 

1.8.2    正向

1.      把target的坐标归一化到[-1,1],变成齐次坐标

2.      调用linalg_gemm,做仿射变换,即Gridgenerator过程

3.      调用BilinearSamplingForward,做双线性差值

代码在 http://rtfcode.com/xref/mxnet-0.12.1/src/operator/spatial_transformer-inl.h#L68

1.8.3    反向

1.      调用BilinearSamplingBackward

2.      调用linalg_geem

 

2    Deformable Convolution

2.1    相关研究

DeformableConvolution借鉴了之前Spatial Transformer Network的bilinearsampling的思路和具体的backpropagation方法,使用了bilinearsampling将任何一个位置的输出,转换成对于feature map的插值操作

2.2    算法原理

虽然CNN中的特征映射图和卷积核是3Dtensor,但deformable convolution是在2D空间域上运行的,并且在通道维度上保持不变。

1. 传统2D卷积

设输入特征图x,输出特征图y,卷积核w,设R表示感受野的大小,以3x3stride=1的卷积为例,R={(-1,1), (-1,0), …, (1,0), (1,1)}定义了9个位置,用N=9表示位置的个数。

于是,计算输出特征图y上的每个位置p0

            从STN网络到deformable convolution

2. 可变形卷积

仅增加了一个参数△pn,其中n=1,2, … , N,计算可变形卷积的公式为,

            从STN网络到deformable convolution

然而,△pn通常是小数,于是就引入了bilinearsampling,来计算变形后的x(p) = x(p0+pn +△pn),

            从STN网络到deformable convolution

其中,q表示p附近的4个点,函数G(·,·)表示bilinearsampling的计算,因为是双线性差值,G(·,·)有x、y两个维度,

            从STN网络到deformable convolution

其中g(a,b) = max(0, 1-|a-b|),这跟STN中的双线性差值是一致的。

            从STN网络到deformable convolution

重点来了,deformableconvolution到底学习的是什么?

可变形卷积通过一个conv层,是对inputfeature map中的每个位置,学习得到offset field,这个offsetfield的大小是和input feature map相同的。

还是以3×3 stride=1的filter为例,N=9,offset field的维度是2N=18,代表了input feature map中每个位置,对应filter的3×3区域中的9个计算点有x、y方向上的2个offset,于是2N=2*3*3=18。

学习得到offsetfield,就能计算x(p0+pn +△pn),就能根据公式计算出outputfeature map中的y(p0)了。

2.3    应用

注意到,deformable convolution的输入输出特征图尺寸与标准卷积是相同的,因此它可以取代在现有的CNNs中原有的卷积模块的位置。         

           

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/180576.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • js 判断字符串中是否包含某个字符include的坑「建议收藏」

    js 判断字符串中是否包含某个字符include的坑「建议收藏」方法一indexOf()(推荐)varstr=”123″;console.log(str.indexOf(“3″)!=-1);//trueindexOf()方法可返回某个指定的字符串值在字符串中首次出现的位置。如果要检索的字符串值没有出现,则该方法返回-1。方法二test()varstr=”123”;varreg=RegExp(/3/);console.log(reg.test(str));//truetest()方法用于检索字

  • Okio基本使用以及源码分析

    Okio基本使用以及源码分析Okio是什么在OkHttp的源码中经常能看到Okio的身影,所以单独拿出来学习一下,作为OkHttp的低层IO库,Okio确实比传统的java输入输出流读写更加方便高效。Okio补充了java.io和java.nio的不足,使访问、存储和处理数据更加容易,它起初只是作为OKHttp的一个组件,现在你可以独立的使用它来解决一些IO问题。先看下okio库中类之间的关系:okio中最关键的是对于缓存队列的管理,这些优化操作使得okio在复制数据的时候可以减少拷贝次数,来看下okio中数据保存的数据结构是

  • ubuntu ipsec配置_ubuntu安装iperf3

    ubuntu ipsec配置_ubuntu安装iperf3ipsetpackageinUbuntuipset:administrationtoolforkernelIPsetsipset-dbgsym:debugsymbolsforipsetlibipset-dev:developmentfilesforIPsetslibipset13:libraryforIPsetslibipset13-dbgs…

  • CentOS 7 升级 Linux 内核

    CentOS 7 升级 Linux 内核升级CentOS内核参考资料1升级CentOS内核参考资料2通过/proc虚拟文件系统读取或配置内核Linux内核官网CentOS官网1.关于Linux内核Linux内核分两种:官方内核(通常是内核开发人员用)和各大Linux发行版内核(一般用户常用)。1.1官方内核在使用Docker时,发现其对Linux内核版本的最低要求…

  • 添加打印机时错误为0x0000011b_连接打印机0x000003e3

    添加打印机时错误为0x0000011b_连接打印机0x000003e3问题描述前几天共享打印机还可以使用的突然就不能打印了,删除重新安装时就提示windows无法连接到打印机,如下图:解决方案这是的补丁代号为KB5005569/KB5005573/KB5005568/KB5005566/KB5005565造成的。卸掉上述补丁即可解决问题步骤找到设置——>更新和安全—->Windows更新—->“查看更新历史记录—->卸载更新本人的经验分享,希望可以帮助到你们,如何不对的地方,可以评论留言,帮我指正一下,如果帮助了你

  • 在互联网上,没有人知道你是一条狗?「建议收藏」

    在互联网上,没有人知道你是一条狗?「建议收藏」1993年,《纽约客》(TheNewYorker)杂志刊登一则由彼得·施泰纳(PeterSteiner)创作的漫画:标题是【OntheInternet,nobodyknowsyou’readog.】这则漫画中有两只狗:一只黑狗站在电脑椅上,爪子扶着键盘。它望向站在地上、表情迷茫的另一只狗,兴奋地说:「在互联网上,没人知道你是一条狗。(OntheInternet,nobodyknowsyou’readog.)画中那只狗的台词随即成了IT界广为流传的经典笑话。那个

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号