Spatial Transformer Network_transgression

Spatial Transformer Network_transgression导读上一篇通俗易懂的SpatialTransformerNetworks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorch和numpy来实现了他们,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块插入到CNN中STN关键点解读STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖θ\thetaθ参数来实现的。刚开

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

导读

上一篇通俗易懂的Spatial Transformer Networks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorchnumpy来实现了,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块嵌入到CNN中

STN关键点解读

STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖 θ \theta θ参数来实现的。刚开始的时候我还以为训练STN还需要准备 θ \theta θ标签数据,实际上并不需要。

当输入图片通过STN模块之后获得变换后的图片,然后我们再将变换后的图片输入到CNN网络中,通过损失函数计算loss,然后计算梯度更新 θ \theta θ参数,最终STN模块会学习到如何矫正图片。

代码实现

  • 导包
import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import numpy as np
from torchsummary import summary
import argparse
  • 定义网络结构
class STN_Net(nn.Module):
def __init__(self,use_stn=True):
super(STN_Net, self).__init__()
self.conv1 = nn.Conv2d(1,10,kernel_size=5)
self.conv2 = nn.Conv2d(10,20,kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320,50)
self.fc2 = nn.Linear(50,10)
#用来判断是否使用STN
self._use_stn = use_stn
#localisation net
#从输入图像中提取特征
#输入图片的shape为(-1,1,28,28)
self.localization = nn.Sequential(
#卷积输出shape为(-1,8,22,22)
nn.Conv2d(1,8,kernel_size=7),
#最大池化输出shape为(-1,1,11,11)
nn.MaxPool2d(2,stride=2),
nn.ReLU(True),
#卷积输出shape为(-1,10,7,7)
nn.Conv2d(8,10,kernel_size=5),
#最大池化层输出shape为(-1,10,3,3)
nn.MaxPool2d(2,stride=2),
nn.ReLU(True)
)
#利用全连接层回归\theta参数
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3,32),
nn.ReLU(True),
nn.Linear(32,2*3)
)
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0]
,dtype=torch.float))
def stn(self,x):
#提取输入图像中的特征
xs = self.localization(x)
xs = xs.view(-1,10*3*3)
#回归theta参数
theta = self.fc_loc(xs)
theta = theta.view(-1,2,3)
#利用theta参数计算变换后图片的位置
grid = F.affine_grid(theta,x.size())
#根据输入图片计算变换后图片位置填充的像素值
x = F.grid_sample(x,grid)
return x
def forward(self,x):
#使用STN模块
if self._use_stn:
x = self.stn(x)
#利用STN矫正过的图片来进行图片的分类
#经过conv1卷积输出的shape为(-1,10,24,24)
#经过max pool的输出shape为(-1,10,12,12)
x = F.relu(F.max_pool2d(self.conv1(x),2))
#经过conv2卷积输出的shape为(-1,20,8,8)
#经过max pool的输出shape为(-1,20,4,4)
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))
x = x.view(-1,320)
x = F.relu(self.fc1(x))
x = F.dropout(x,training=self.training)
x = self.fc2(x)
return F.log_softmax(x,dim=1)
  • 加载数据集
def get_dataloader(batch_size):
# 加载数据集
# 如果GPU可用就用GPU,否则用CPU
device = torch.device("cuda" if torch.cuda.is_available()
else "cpu")
# 加载训练集
train_dataloader = torch.utils.data.DataLoader(
datasets.MNIST(root="D:/dataset", train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=batch_size, shuffle=True)
# 加载测试集
test_dataloader = torch.utils.data.DataLoader(
datasets.MNIST(root="D:/dataset", train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=batch_size, shuffle=True)
return train_dataloader,test_dataloader
  • 训练模型
def train(net,epoch_nums,lr,train_dataloader,per_batch,device):
#使用训练模式
net.train()
#选择梯度下降优化算法
optimizer = optim.SGD(net.parameters(),lr=lr)
#训练模型
for epoch in range(epoch_nums):
for batch_idx,(data,label) in enumerate(train_dataloader):
data,label = data.to(device),label.to(device)
optimizer.zero_grad()
pred = net(data)
loss = F.nll_loss(pred,label)
loss.backward()
optimizer.step()
if batch_idx % per_batch == 0:
print("Train Epoch:{ 
} [{ 
}/{ 
} ({ 
:.0f}%)]\tLoss:
{ 
:.6f}".format(epoch,batch_idx * len(data),
len(train_dataloader.dataset),
100. * batch_idx /len(train_dataloader),loss.item()))
  • 评估模型
def evaluate(net,test_dataloader,device):
with torch.no_grad():
#使用评估模式
net.eval()
eval_loss = 0
eval_acc = 0
for data,label in test_dataloader:
data,label = data.to(device),label.to(device)
pred = net(data)
eval_loss += F.nll_loss(pred,label,
size_average=False).item()
pred_label = pred.max(1,keepdim=True)[1]
eval_acc += pred_label.eq(label.view_as(pred_label)
).sum().item()
eval_loss /= len(test_dataloader.dataset)
print("evaluate set: Average loss: { 
:.4f},Accuracy:{ 
}/{ 
} 
({ 
:.2f}%)\n".format(
eval_loss,eval_acc,len(test_dataloader.dataset),
100*eval_acc / len(test_dataloader.dataset)))
  • 将pytorch的tensor转换为numpy的array
def tensor_to_array(img_tensor):
img_array = img_tensor.numpy().transpose((1,2,0))
mean = np.array([0.485,0.456,0.406])
std = np.array([0.229,0.224,0.225])
img_array = std * img_array + mean
img = np.clip(img_array,0,1)
return img
  • 可视化STN变换图片
def visualize_stn(net,dataloader,device):
with torch.no_grad():
data = next(iter(dataloader))[0].to(device)
input_tensor = data.cpu()
t_input_tensor = net.stn(data).cpu()
in_grid = tensor_to_array(torchvision.utils.make_grid(
input_tensor))
out_grid = tensor_to_array(torchvision.utils.make_grid(
t_input_tensor))
f,axarr = plt.subplots(1,2)
axarr[0].imshow(in_grid)
axarr[0].set_title("input images")
axarr[1].imshow(out_grid)
axarr[1].set_title("stn transformed images")
plt.show()

在这里插入图片描述
通过对比输入图片和经过STN变换后的图片能够很明显发现,经过STN之后能将旋转的图片进行明显的纠正。

  • 参数设置
def parse_args():
parse = argparse.ArgumentParser("config stn args")
parse.add_argument("--lr",default=0.01,
type=float,help="learning rate")
parse.add_argument("--epoch_nums",default=20,
type=int,help="iterated epochs")
parse.add_argument("--use_stn",default=True,
type=bool,help="whether to use STN module")
parse.add_argument("--batch_size",default=64,
type=int,help="batch size")
parse.add_argument("--use_eval",default=True,
type=bool,help="whether to evaluate")
parse.add_argument("--use_visual",default=True,
type=bool,help="visual STN transform image")
parse.add_argument("--use_gpu",default=True,
type=bool,help="whether to use GPU")
parse.add_argument("--show_net_construct",default=False,
type=bool,help="print net construct info")
return parse.parse_args()
  • 主函数
if __name__ == "__main__":
args = parse_args()
if args.use_gpu and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
#加载数据集
train_loader,test_loader = get_dataloader(args.batch_size)
#创建网络
net = STN_Net(args.use_stn).to(device)
#打印网络的结构信息
if args.show_net_construct:
summary(net,(1,28,28))
#训练模型
train(net,args.epoch_nums,args.lr,train_loader
,args.batch_size,device)
if args.use_eval:
#评估模型
evaluate(net,test_loader,device)
if args.use_visual:
#可视化展示效果
visualize_stn(net,test_loader,device)

参考:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/183487.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 最大共识面临崩塌?比特币要增发?

    最大共识面临崩塌?比特币要增发?白话区块链从入门到精通,看我就够了!两天前,江卓尔的一条微博,一石激起千层浪。原文是这样的:比特币Core下一目标是增发比特币,修改其上限2100万,停止减半。没错,不要…

  • js判断是否包含指定字符串_js字符串包含字符串

    js判断是否包含指定字符串_js字符串包含字符串我是想在js中判断字符串是否包含某个中文,将方法记录起来,这些方法也适用于数字、字母。实践是检验真理的唯一标准,还是要多多测试啊。String对象的方法方法一:indexOf()vargroupName=”小白A组”;alert(‘groupName.indexOf()=’+(groupName.indexOf(“组”)!=-1));//trueindex…

  • 微服务架构技术有哪些_微服务架构组件

    微服务架构技术有哪些_微服务架构组件目录一、微服务架构实现需求二、微服务架构实现技术选型:参考标准的两个维度+微服务实现框架对比(一)技术选型的两个参考标准1.核心组件完备性2.关键要素实现难度(二)微服务实现框架对比SpringBoot/CloudDubbogRPC新锐微服务框架:Istio(ServiceMesh的设计理念)参考书籍、文献和资料:一、微服务架构实现需求技…

    2022年10月21日
  • 计算机操作系统-操作系统的定义

    计算机系统的层次结构 用户 应用程序 操作系统 纯硬件:CPU、RAM、ROM 其中,操作系统:从操作系统层往两侧看:负责管理协调硬件、软件等计算机资源的工作 从上往下看:为上层的应用程序和用户提供简单易用的服务 从下往上看:操作系统系统软件,而不是硬件定义OperatingSystem是指控制和管理整个计算机系统的硬件和软件资源,并合理地组…

  • IDEA2021.2安装与配置(持续更新)「建议收藏」

    IDEA2021.2安装与配置一、下载二、安装三、配置配置全局生效首次启动激活字体,字体大小配色方案注解生效自动导包移包自动补全快捷键格式化代码代码忽略大小写git配置maven配置四、插件Vue.jsTranslationlombok一、下载下载地址:https://www.jetbrains.com/zh-cn/idea/download/other.html选择相应的版本下载,这里以2021.2版本为例。二、安装更改安装位置创建桌面快捷方式三、配置配置全局生效不要打开项目,直

  • 其实Unix很简单

    很多编程的朋友都在网上问我这样的几个问题,Unix怎么学?Unix怎么这么难?如何才能学好?并且让我给他们一些学好Unix的经验。在绝大多数时候,我发现问这些问题的朋友都有两个特点:1)对Unix有

    2021年12月27日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号