siamfc-pytorch代码讲解(二):train&siamfc「建议收藏」

siamfc-pytorch代码讲解(二):train&siamfc「建议收藏」siamfc-pytorch代码讲解(二):train&siamfc一、train.py二、siamfc.py2.1SiamFCTransforms2.2Pair2.3train_step下一篇这是第二篇的siamfc-pytorch代码讲解,主要顺着程序流讲解代码,上一篇讲解在这里:siamfc-pytorch代码讲解(一):backbone&headshowme…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

siamfc-pytorch代码讲解(二):train&siamfc

这是第二篇的siamfc-pytorch代码讲解,主要顺着程序流讲解代码,上一篇讲解在这里:
siamfc-pytorch代码讲解(一):backbone&head
show me code !!

一、train.py

因为作者使用了GOT-10k这个工具箱,train.py代码非常少,就下面几行:

from __future__ import absolute_import

import os
from got10k.datasets import *

from siamfc import TrackerSiamFC

if __name__ == '__main__':
    root_dir = os.path.expanduser('~/data/GOT-10k')
    seqs = GOT10k(root_dir, subset='train', return_meta=True)

    tracker = TrackerSiamFC()
    tracker.train_over(seqs)

首先我们就需要按照GOT-10k download界面去下载好数据集,并且按照这样的文件结构放好(因为现在用不到验证集和测试集,可以先不用下,训练集也只要先下载1个split,所以就需要把list.txt中只保留前500项,因为GOT-10k_Train_000001里面有500个squences):

  |-- GOT-10k/
     |-- train/
     |  |-- GOT-10k_Train_000001/
     |  |   ......
     |  |-- GOT-10k_Train_000500/
     |  |-- list.txt

这里可以打印一下seps到底是什么,因为他是train_over的入参:

print(seqs)
# <got10k.datasets.got10k.GOT10k object at 0x000002366865CF28>
print(seqs[0])
# 这里比较多,截取一部分
# seqs[0]就是指第一个序列GOT-10k_Train_000001,返回三个元素的元组
# 第一个元素是一个路径列表,第二个是np.ndarray,第三个是字典,包含具体信息
# (['D:\\GOT-10k\\train\\GOT-10k_Train_000001\
print(seqs)
# <got10k.datasets.got10k.GOT10k object at 0x000002366865CF28>
print(seqs[0])
# 这里比较多,截取一部分
# seqs[0]就是指第一个序列GOT-10k_Train_000001,返回三个元素的元组
# 第一个元素是一个路径列表,第二个是np.ndarray,第三个是字典,包含具体信息
# (['D:\\GOT-10k\\train\\GOT-10k_Train_000001\\00000001.jpg', ...],
# array([[347., 443., 429., 272.],...[551., 467., 513., 318.]]),
# {'url': 'https://youtu.be/b0ZnfLI8YPw',...})
000001.jpg', ...],
# array([[347., 443., 429., 272.],...[551., 467., 513., 318.]]), # {'url': 'https://youtu.be/b0ZnfLI8YPw',...})

二、siamfc.py

顺着代码流看到调用了siamfc.py中类TrackerSiamFC的train_over方法,在这个类里面就是进行数据增强,构造和加载,然后进行训练,这里主要有三个地方需要深入思考一下:

2.1 SiamFCTransforms

SiamFCTransforms是transforms.py里面的一个类,主要是对输入的groung truth的z, x, bbox_z, bbox_x进行一系列变换,构成孪生网络的输入,这其中就包括了:

  • RandomStretch:主要是随机的resize图片的大小,其中要注意cv2.resize()的一点用法,可以参考我的这篇博客:cv2.resize()的一点小坑
  • CenterCrop:从img中间抠一块(size, size)大小的patch,如果不够大,以图片均值进行pad之后再crop
  • RandomCrop:用法类似CenterCrop,只不过从随机的位置抠,没有pad的考虑
  • Compose:就是把一系列的transforms串起来
  • ToTensor: 就是字面意思,把np.ndarray转化成torch tensor类型

类初始化里面针对self.transforms_zself.transforms_x数据增强方法中具体参数的设置可以参考 issue#21,作者提到在train phase和test phase embedding size不一样没太大的影响,而且255-16可以模拟测试阶段目标的移动(个人感觉这里没有完全就按照论文上来,但也不用太在意,自己可以试着改回来看哪一个效果好)。
下面具体讲里面的_crop函数:

    def _crop(self, img, box, out_size):
        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2,
            box[0] - 1 + (box[2] - 1) / 2,
            box[3], box[2]], dtype=np.float32)
        center, target_sz = box[:2], box[2:]

        context = self.context * np.sum(target_sz)
        size = np.sqrt(np.prod(target_sz + context))
        size *= out_size / self.exemplar_sz

        avg_color = np.mean(img, axis=(0, 1), dtype=float)
        interp = np.random.choice([
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_NEAREST,
            cv2.INTER_LANCZOS4])
        patch = ops.crop_and_resize(
            img, center, size, out_size,
            border_value=avg_color, interp=interp)
        
        return patch

因为GOT-10k里面对于目标的bbox是以ltwh(即left, top, weight, height)形式给出的,上述代码一开始就先把输入的box变成center based,坐标形式变为[y, x, h, w],结合下面这幅图就非常好理解: y = t + h 2 y=t+\frac{h}{2} y=t+2h x = l + w 2 x=l+\frac{w}{2} x=l+2w
the annotation form of bbox in GOT-10k
crop_and_resize

def crop_and_resize(img, center, size, out_size,
                    border_type=cv2.BORDER_CONSTANT,
                    border_value=(0, 0, 0),
                    interp=cv2.INTER_LINEAR):
    # convert box to corners (0-indexed)
    size = round(size)  # the size of square crop
    corners = np.concatenate((
        np.round(center - (size - 1) / 2),
        np.round(center - (size - 1) / 2) + size))
    corners = np.round(corners).astype(int)

    # pad image if necessary
    pads = np.concatenate((
        -corners[:2], corners[2:] - img.shape[:2]))
    npad = max(0, int(pads.max()))
    if npad > 0:
        img = cv2.copyMakeBorder(
            img, npad, npad, npad, npad,
            border_type, value=border_value)

    # crop image patch
    corners = (corners + npad).astype(int)
    patch = img[corners[0]:corners[2], corners[1]:corners[3]]

    # resize to out_size
    patch = cv2.resize(patch, (out_size, out_size),
                       interpolation=interp)

    return patch

现在进入ops.py下的crop_and_resize,这个函数的作用就是crop一块以object为中心的,边长为size大小的patch(如下面的绿色虚线的正方形框),然后将其resize成out_size的大小:传入sizecenter计算出角落坐标形式的正方形patch,即 ( y m i n , x m i n , y m a x , x m a x ) (y_{min},x_{min},y_{max},x_{max}) (ymin,xmin,ymax,xmax),如下面两个点粗的绿点,因为这样扩大的坐标有可能会超出原来的图片(如粉红色线所表示),所以就要计算左上角和右下角相对原图片超出多少,好进行pad,上面13-14行代码就是干这事。然后根据他们超出当中的最大值npad来在原图像周围pad(也就是最外面的蓝框),因为原图像增大了,所以corner相对坐标也变了,对应上面22行代码。然后就是crop下来绿色方形的那块,经过resize后返回。
关于_crop函数的图示
下面就是我的实验结果:
原图与抠下来的patch图对比

2.2 Pair

现在继续回到train_over方法,里面构造dataset的时候用了Pair类,所以从代码角度具体来看一下,因为继承了Dataset类,所以要overwrite __getitem____len__方法:

  • __getitem__:分析代码,这个方法就是通过index索引返回item = (z, x, box_z, box_x),然后经过transforms返回一对pair(z, x),就需要像论文里面说的:The images are extracted from two frames of a video that both contain the object and are at most T frames apart
    • _filter:通过该函数筛选符合条件的有效索引val_indices,这里不详细分析,因为我也不知道为什么会有这样的filter condition。
    • _sample_pair:如果有效索引大于2个的话,就从中随机挑选两个索引,这里取的间隔不超过T=100
  • __len__:这里定义的长度就是被索引到的视频序列帧数×每个序列提供的对数

2.3 train_step

现在来到siamfc.py里面最后一个关键的地方,数据准备好了,经过变换和加载进来就可以训练了,下面代码是常规操作,具体在train_step里面实现了训练和反向传播:

for epoch in range(self.cfg.epoch_num):
    # update lr at each epoch
    self.lr_scheduler.step(epoch=epoch)

    # loop over dataloader
    for it, batch in enumerate(dataloader):
        loss = self.train_step(batch, backward=True)
        print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
            epoch + 1, it + 1, len(dataloader), loss))
        sys.stdout.flush()

train_step里面难度又是在于理解_create_labels,具体的一些tensor的shape可以看我的注释,我好奇就把他打印出来了,看来本来__getitem__返回一对pair(z, x),经过dataloader的加载,还是z堆叠一起,x堆叠一起,并不是(z, x)绑定堆叠一起【主要自己对dataloader源码不是很熟,手动捂脸】
而且criterion使用的BalancedLoss,是调用F.binary_cross_entropy_with_logits,进行一个element-wise的交叉熵计算,所以创建出来的labels的shape其实就是和responses的shape是一样的:

def train_step(self, batch, backward=True):
    # set network mode
    self.net.train(backward)

    # parse batch data
    z = batch[0].to(self.device, non_blocking=self.cuda)
    x = batch[1].to(self.device, non_blocking=self.cuda)
    # print("batch_z shape:", z.shape) # torch.Size([8, 3, 127, 127])
    # print("batch_x shape:", x.shape) # torch.Size([8, 3, 239, 239])

    with torch.set_grad_enabled(backward):
        # inference
        responses = self.net(z, x)
        # print("responses shape:", responses.shape) # torch.Size([8, 1, 15, 15])

        # calculate loss
        labels = self._create_labels(responses.size())
        loss = self.criterion(responses, labels)

        if backward:
            # back propagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

创建标签,论文里是这么说的:
y [ u ] = { + 1  if  k ∥ u − c ∥ ≤ R − 1  otherwise  y[u]=\left\{\begin{array}{ll} {+1} & {\text { if } k\|u-c\| \leq R} \\ {-1} & {\text { otherwise }} \end{array}\right. y[u]={
+11 if kucR otherwise 

因为我们的exemplar image z z z 和search image x x x都是以目标为中心的,所以labels的中心为1,中心以外为0。

def _create_labels(self, size):
    # skip if same sized labels already created
    if hasattr(self, 'labels') and self.labels.size() == size:
        return self.labels

    def logistic_labels(x, y, r_pos, r_neg):
        dist = np.abs(x) + np.abs(y)  # block distance
        labels = np.where(dist <= r_pos,
                          np.ones_like(x),
                          np.where(dist < r_neg,
                                   np.ones_like(x) * 0.5,
                                   np.zeros_like(x)))
        return labels

    # distances along x- and y-axis
    n, c, h, w = size
    x = np.arange(w) - (w - 1) / 2
    y = np.arange(h) - (h - 1) / 2
    x, y = np.meshgrid(x, y)

    # create logistic labels 这里除以stride,是相对score map上来说
    r_pos = self.cfg.r_pos / self.cfg.total_stride
    r_neg = self.cfg.r_neg / self.cfg.total_stride
    labels = logistic_labels(x, y, r_pos, r_neg)

    # repeat to size
    labels = labels.reshape((1, 1, h, w))
    labels = np.tile(labels, (n, c, 1, 1))

    # convert to tensors
    self.labels = torch.from_numpy(labels).to(self.device).float()

    return self.labels

其中关于np.tile、np.meshgrid、np.where函数的使用可以去看我这篇博客,最后出来的一个batch下某一个通道下的label就是下面这样的,有没有一种扫雷的既视感,?:
labels的创建
至此此份repo的训练应该差不多结束了,测试部分(inference phase)我还没怎么看,且涉及到GOT-10k的使用,下一次有空再看再写~

上下篇

上一篇:siamfc-pytorch代码讲解(一):backbone&head
下一篇:siamfc-pytorch代码讲解(三):demo&track

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/187096.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • Python表白代码合集:5种表白代码,找不到对象你来找我,这也太秀了叭

    Python表白代码合集:5种表白代码,找不到对象你来找我,这也太秀了叭文章目录一、容我啰嗦两句二、来吧,代码展示1、给女神比个小心心2、无限弹窗式表白3、这货不是表白代码,悄悄送给你们4、520表白墙5、抖音热门表白小软件6、无套路表白三、写在最后一、容我啰嗦两句爬虫看多了,对身体不好,我们来点现实的,学学表白找个女朋友他不香吗,对吧~文章最后教你们怎么打包成exe,如果你懒得搞懂代码怎么回事,直接复制代码打包成exe运行就好了。这样不管你发给别人也好,以后方便直接用也好,都很方便。咱就不整什么鸡皮疙瘩掉一地的情话啥的了,有需要的自行百度。二、来吧,代码展示

  • 国内免费php mysql空间,[php mysql]国内有什么好的免备案免费php+mysql空间

    国内免费php mysql空间,[php mysql]国内有什么好的免备案免费php+mysql空间国内有什么好的免备案免费php+mysql空间问题补充:稳定点的,速度过得去就可以.谢谢●我一直在用”主机屋”提供的免费空间.稳定性可以运作几年了.不用备案.用了马上就知道好.百度搜索”主机屋”php+mysql实现无限级分类问题补充:php+mysql实现无限级分类●项目思路分析:一个PHP项目要用到分类,但不确定分几级,所以就想做成无限级分类。一开始想是按以前一样,数据库建4个值,如下:…

  • 复合主键与联合主键[通俗易懂]

    复合主键与联合主键[通俗易懂]一、复合主键 所谓的复合主键就是指你表的主键含有一个以上的字段组成,不使用无业务含义的自增id作为主键。比如 createtabletest(namevarchar(19),idnumber,valuevarchar(10),primarykey(name,id))上面的name和id字段组合起来就是你

  • PC上安装多个操作系统

    目 录 第1章 绪论1 1.1目标1 1.2适宜的读者1 第2章 制作启动U盘2 2.1初级安装2 2.2启动分析3 2.3高级安装13 2.3.1分区13 2.3.2复制文件16 2.3.3…

  • 最小二乘法正规方程推导过程

    最小二乘法正规方程推导过程最小二乘法正规方程推导过程线性回归岭回归:添加L2L_2L2​正则项输入样本X∈Rm×n\textbf{X}\in\mathbb{R}^{m\timesn}X∈Rm×n,输出y∈Rm×1\textbf{y}\in\mathbb{R}^{m\times1}y∈Rm×1,需要学习的参数w∈Rn×1\textbf{w}\in\mathbb{R}^{n\times1}w∈Rn×1。其中,mmm为样本个数,nnn为单个样本维度。线性回归最小化目标函数J(w)=12∥y−Xw∥22J(\

  • matlab保存图片函数后突变分辨变化,MATLAB总结 – 图片保存「建议收藏」

    matlab保存图片函数后突变分辨变化,MATLAB总结 – 图片保存「建议收藏」I.Matlab中保存图片的方法1.一种是出来图形窗口后手动保存(这儿又可以分两种):1.1直接从菜单保存,有fig,eps,jpeg,gif,png,bmp等格式。1.2edit——〉copyfigure,再粘贴到其他程序。2.另一种是用命令直接保存(这里也有两种):2.1用saveas命令保存图片。saveas的三个参数:(1)图形句柄,如果图形窗口标题栏是“Figure3…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号