Spatial Transformer Network_transgression

Spatial Transformer Network_transgression导读上一篇通俗易懂的SpatialTransformerNetworks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorch和numpy来实现了他们,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块插入到CNN中STN关键点解读STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖θ\thetaθ参数来实现的。刚开

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

导读

上一篇通俗易懂的Spatial Transformer Networks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorchnumpy来实现了,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块嵌入到CNN中

STN关键点解读

STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖 θ \theta θ参数来实现的。刚开始的时候我还以为训练STN还需要准备 θ \theta θ标签数据,实际上并不需要。

当输入图片通过STN模块之后获得变换后的图片,然后我们再将变换后的图片输入到CNN网络中,通过损失函数计算loss,然后计算梯度更新 θ \theta θ参数,最终STN模块会学习到如何矫正图片。

代码实现

  • 导包
import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import numpy as np
from torchsummary import summary
import argparse
  • 定义网络结构
class STN_Net(nn.Module):
    def __init__(self,use_stn=True):
        super(STN_Net, self).__init__()
        self.conv1 = nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = nn.Conv2d(10,20,kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320,50)
        self.fc2 = nn.Linear(50,10)
        #用来判断是否使用STN
        self._use_stn = use_stn

        #localisation net
        #从输入图像中提取特征
        #输入图片的shape为(-1,1,28,28)
        self.localization = nn.Sequential(
            #卷积输出shape为(-1,8,22,22)
            nn.Conv2d(1,8,kernel_size=7),
            #最大池化输出shape为(-1,1,11,11)
            nn.MaxPool2d(2,stride=2),
            nn.ReLU(True),
            #卷积输出shape为(-1,10,7,7)
            nn.Conv2d(8,10,kernel_size=5),
            #最大池化层输出shape为(-1,10,3,3)
            nn.MaxPool2d(2,stride=2),
            nn.ReLU(True)
        )
        #利用全连接层回归\theta参数
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3,32),
            nn.ReLU(True),
            nn.Linear(32,2*3)
        )

        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0]
        ,dtype=torch.float))

    def stn(self,x):
        #提取输入图像中的特征
        xs = self.localization(x)
        xs = xs.view(-1,10*3*3)
        #回归theta参数
        theta = self.fc_loc(xs)
        theta = theta.view(-1,2,3)

        #利用theta参数计算变换后图片的位置
        grid = F.affine_grid(theta,x.size())
        #根据输入图片计算变换后图片位置填充的像素值
        x = F.grid_sample(x,grid)

        return x

    def forward(self,x):
        #使用STN模块
        if self._use_stn:
            x = self.stn(x)
        #利用STN矫正过的图片来进行图片的分类
        #经过conv1卷积输出的shape为(-1,10,24,24)
        #经过max pool的输出shape为(-1,10,12,12)
        x = F.relu(F.max_pool2d(self.conv1(x),2))
        #经过conv2卷积输出的shape为(-1,20,8,8)
        #经过max pool的输出shape为(-1,20,4,4)
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))
        x = x.view(-1,320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x,training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x,dim=1)
  • 加载数据集
def get_dataloader(batch_size):
    # 加载数据集
    # 如果GPU可用就用GPU,否则用CPU
    device = torch.device("cuda" if torch.cuda.is_available()
    					   else "cpu")
    # 加载训练集
    train_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root="D:/dataset", train=True, download=True,
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])), batch_size=batch_size, shuffle=True)

    # 加载测试集
    test_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root="D:/dataset", train=False,
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])), batch_size=batch_size, shuffle=True)

    return train_dataloader,test_dataloader
  • 训练模型
def train(net,epoch_nums,lr,train_dataloader,per_batch,device):
    #使用训练模式
    net.train()
    #选择梯度下降优化算法
    optimizer = optim.SGD(net.parameters(),lr=lr)
    #训练模型
    for epoch in range(epoch_nums):
        for batch_idx,(data,label) in enumerate(train_dataloader):
            data,label = data.to(device),label.to(device)

            optimizer.zero_grad()
            pred = net(data)
            loss = F.nll_loss(pred,label)
            loss.backward()
            optimizer.step()

            if batch_idx % per_batch == 0:
                print("Train Epoch:{ 
   } [{ 
   }/{ 
   } ({ 
   :.0f}%)]\tLoss:
                { 
   :.6f}".format(epoch,batch_idx * len(data),
                len(train_dataloader.dataset),
                100. * batch_idx /len(train_dataloader),loss.item()))
  • 评估模型
def evaluate(net,test_dataloader,device):
    with torch.no_grad():
        #使用评估模式
        net.eval()
        eval_loss = 0
        eval_acc = 0
        for data,label in test_dataloader:
            data,label = data.to(device),label.to(device)
            pred = net(data)

            eval_loss += F.nll_loss(pred,label,
            size_average=False).item()
            pred_label = pred.max(1,keepdim=True)[1]
            eval_acc += pred_label.eq(label.view_as(pred_label)
            ).sum().item()

        eval_loss /= len(test_dataloader.dataset)
        print("evaluate set: Average loss: { 
   :.4f},Accuracy:{ 
   }/{ 
   } 
        ({ 
   :.2f}%)\n".format(
            eval_loss,eval_acc,len(test_dataloader.dataset),
            100*eval_acc / len(test_dataloader.dataset)))
  • 将pytorch的tensor转换为numpy的array
def tensor_to_array(img_tensor):
    img_array = img_tensor.numpy().transpose((1,2,0))
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img_array = std * img_array + mean
    img = np.clip(img_array,0,1)
    return img
  • 可视化STN变换图片
def visualize_stn(net,dataloader,device):
    with torch.no_grad():
        data = next(iter(dataloader))[0].to(device)

        input_tensor = data.cpu()
        t_input_tensor = net.stn(data).cpu()

        in_grid = tensor_to_array(torchvision.utils.make_grid(
        input_tensor))
        out_grid = tensor_to_array(torchvision.utils.make_grid(
        t_input_tensor))

        f,axarr = plt.subplots(1,2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title("input images")

        axarr[1].imshow(out_grid)
        axarr[1].set_title("stn transformed images")

        plt.show()

在这里插入图片描述
通过对比输入图片和经过STN变换后的图片能够很明显发现,经过STN之后能将旋转的图片进行明显的纠正。

  • 参数设置
def parse_args():
    parse = argparse.ArgumentParser("config stn args")
    parse.add_argument("--lr",default=0.01,
    type=float,help="learning rate")
    parse.add_argument("--epoch_nums",default=20,
    type=int,help="iterated epochs")
    parse.add_argument("--use_stn",default=True,
    type=bool,help="whether to use STN module")
    parse.add_argument("--batch_size",default=64,
    type=int,help="batch size")
    parse.add_argument("--use_eval",default=True,
    type=bool,help="whether to evaluate")
    parse.add_argument("--use_visual",default=True,
    type=bool,help="visual STN transform image")
    parse.add_argument("--use_gpu",default=True,
    type=bool,help="whether to use GPU")
    parse.add_argument("--show_net_construct",default=False,
    type=bool,help="print net construct info")
    return parse.parse_args()
  • 主函数
if __name__ == "__main__":
    args = parse_args()
    if args.use_gpu and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    #加载数据集
    train_loader,test_loader = get_dataloader(args.batch_size)
    #创建网络
    net = STN_Net(args.use_stn).to(device)
    #打印网络的结构信息
    if args.show_net_construct:
        summary(net,(1,28,28))
    #训练模型
    train(net,args.epoch_nums,args.lr,train_loader
    ,args.batch_size,device)
    if args.use_eval:
        #评估模型
        evaluate(net,test_loader,device)
    if args.use_visual:
        #可视化展示效果
        visualize_stn(net,test_loader,device)

参考:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/183487.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • Matlab fmincon[通俗易懂]

    Matlab fmincon[通俗易懂]fmincon的官方文档fmincon非线性优化

  • https和http有哪些区别「建议收藏」

    https和http有哪些区别「建议收藏」什么是HTTPS?HTTPS的全称是超文本传输安全协议(HypertextTransferProtocolSecure),是一种网络安全传输协议。在HTTP的基础上加入SSL/TLS来进行数据加密,保护交换数据不被泄露、窃取。通俗的来说,就是:当你登陆一个有网站的网页时形成,在填写该表格并点击“提交”后,您输入的信息可能被黑客截获不安全网站。这些信息可以是银行交易的详细信…

    2022年10月17日
  • Deepin安装MariaDB数据库

    Deepin安装MariaDB数据库在deeping上安装mariadb 1,安装的官网参考:有安装的命令和指导https://downloads.mariadb.org/mariadb/repositories/#distro=Debian&distro_release=jessie–jessie&mirror=tuna&version=10.4其主官网地址:https://downloads.mariadb.org/  2,安装mari…

  • 【android】在eclipse中查看genymotion模拟器的sd卡文件夹

    【android】在eclipse中查看genymotion模拟器的sd卡文件夹

  • 【C#基础】-Substring截取字符串的方法小结

    【C#基础】-Substring截取字符串的方法小结前言    在公司的图书馆项目中曾经用过截取字符串的方法,项目是java语言的;最近在公司的另一个项目中又需要截取字符串,一种环境是C#语言,一种环境是SQLServer存储过程;先来说一下后台程序中截取字符串的方法。正文c#中截取字符串主要是借助Substring这个函数。stringstring.Substring(intstartIndex,intlength)

  • keypad(键盘矩阵)指南

    keypad(键盘矩阵)指南目录keyPad简介API说明示例常见问题相关资料以及开发板购买链接keyPad简介Air724UG支持6X6键盘矩阵,可以在luat二次开发的方式应用,但注意AT版本不支持键盘功能。API说明API接口描述powerKey.setup(longPrd,longCb,shortCb)开机键功能配置常用api_1介绍常用api_2介绍示例1.创建一个tKeypad表,储存所有按键值(16个键盘元素+1个开关机键元素)–每个元素的索引为行列值拼接而成的字符

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号