STN in PyTorch [Easy to Understand]


# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about the spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__.

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, it can crop a region of interest, or scale and correct
the orientation of an image. This is a useful mechanism because CNNs
are not invariant to rotation, scale, or more general affine
transformations.

One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np



plt.ion()   # interactive mode

######################################################################
# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

use_cuda = torch.cuda.is_available()

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=128, shuffle=True, num_workers=4)


# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)
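The `Normalize((0.1307,), (0.3081,))` step used by both loaders maps each pixel value x in [0, 1] to (x - mean) / std. A minimal plain-Python sketch of that arithmetic and of its inverse (which is what a display helper needs in order to show the images again) might look like this. The helper names `normalize` and `denormalize` are hypothetical and are not part of torchvision.

```python
MNIST_MEAN = 0.1307  # per-channel mean passed to transforms.Normalize
MNIST_STD = 0.3081   # per-channel standard deviation passed to transforms.Normalize


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Map a pixel value in [0, 1] to its normalized value."""
    return (pixel - mean) / std


def denormalize(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Invert normalize(), recovering the original pixel value."""
    return value * std + mean


# Round-trip a few pixel values through both functions
for pixel in (0.0, 0.5, 1.0):
    z = normalize(pixel)
    print(pixel, "->", round(z, 4), "->", round(denormalize(z), 4))
```

A pixel equal to the mean maps to zero, and applying `denormalize` after `normalize` returns the original value up to floating-point rounding.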

######################################################################
# Depicting spatial transformer networks
# --------------------------------------
#
# Spatial transformer networks boil down to three main components:
#
# -  The localization network is a regular CNN which regresses the
#    transformation parameters. The transformation is never learned
#    explicitly from this dataset; instead, the network automatically
#    learns the spatial transformations that enhance the global accuracy.
# -  The grid generator generates a grid of coordinates in the input
#    image corresponding to each pixel from the output image.
# -  The sampler uses the parameters of the transformation and applies
#    it to the input image.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. Note::
#    We need a version of PyTorch that provides the
#    ``affine_grid`` and ``grid_sample`` functions.
#
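The grid generator and sampler described above can be sketched in plain Python for a single-channel image. This is only an illustration of the idea, not PyTorch's implementation: the function names `make_affine_grid` and `bilinear_sample` are hypothetical, coordinates are normalized to [-1, 1] with pixel centres as sample points (an assumed convention), and locations outside the image read as zero.

```python
import math


def make_affine_grid(theta, height, width):
    """For each output pixel, return the (x, y) source location in [-1, 1].

    theta is a 2 x 3 affine matrix given as two rows of three floats.
    """
    grid = []
    for i in range(height):
        y = (2 * i + 1) / height - 1          # normalized row centre
        row = []
        for j in range(width):
            x = (2 * j + 1) / width - 1       # normalized column centre
            src_x = theta[0][0] * x + theta[0][1] * y + theta[0][2]
            src_y = theta[1][0] * x + theta[1][1] * y + theta[1][2]
            row.append((src_x, src_y))
        grid.append(row)
    return grid


def bilinear_sample(image, grid):
    """Sample a 2-D image (list of rows) at the grid locations."""
    height, width = len(image), len(image[0])

    def pixel(r, c):
        # Zero padding for locations outside the image
        if 0 <= r < height and 0 <= c < width:
            return float(image[r][c])
        return 0.0

    out = []
    for grid_row in grid:
        out_row = []
        for src_x, src_y in grid_row:
            # Convert normalized coordinates back to pixel coordinates
            px = ((src_x + 1) * width - 1) / 2
            py = ((src_y + 1) * height - 1) / 2
            x0, y0 = math.floor(px), math.floor(py)
            wx, wy = px - x0, py - y0
            top = (1 - wx) * pixel(y0, x0) + wx * pixel(y0, x0 + 1)
            bottom = (1 - wx) * pixel(y0 + 1, x0) + wx * pixel(y0 + 1, x0 + 1)
            out_row.append((1 - wy) * top + wy * bottom)
        out.append(out_row)
    return out


image = [[1, 2, 3, 4],
         [5, 6, 7, 8]]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
shift_x = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0]]

same = bilinear_sample(image, make_affine_grid(identity, 2, 4))
shifted = bilinear_sample(image, make_affine_grid(shift_x, 2, 4))
print(same)
print(shifted)
```

With the identity matrix the output equals the input. Under this convention a translation of 0.5 in x corresponds to one pixel in a 4-pixel-wide image, so each output pixel reads the value one pixel to its right: the content shifts left and a zero enters on the right edge.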


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 2 x 3 affine matrix (6 parameters)
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.fill_(0)
        self.fc_loc[2].bias.data = torch.FloatTensor([1, 0, 0, 0, 1, 0])

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)     # N x 10 x 3 x 3 for 28 x 28 MNIST input
        xs = xs.view(-1, 10 * 3 * 3)  # flatten, like reshape(): N x 90
        theta = self.fc_loc(xs)       # N x 6
        theta = theta.view(-1, 2, 3)  # N x 2 x 3 affine transform matrices

        # affine_grid arguments and result:
        #   theta: batch of affine matrices (N x 2 x 3)
        #   size:  target output image size (N x C x H x W), here N x 1 x 28 x 28
        #   grid:  sampling coordinates (N x H x W x 2), here N x 28 x 28 x 2
        grid = F.affine_grid(theta, x.size())

        # grid_sample arguments and result:
        #   input:        batch of images (N x C x IH x IW)
        #   grid:         flow field (N x OH x OW x 2)
        #   padding_mode: how grid values outside the image are handled,
        #                 e.g. 'zeros' or 'border' (default: 'zeros')
        #   output:       N x C x OH x OW
        x = F.grid_sample(x, grid)

        return x  # N x 1 x 28 x 28

    def forward(self, x):  # x: N x 1 x 28 x 28
        # transform the input
        x = self.stn(x)    # N x 1 x 28 x 28

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # N x 10 x 12 x 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # N x 20 x 4 x 4
        x = x.view(-1, 320)      # N x 320
        x = F.relu(self.fc1(x))  # N x 50
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)          # N x 10
        return F.log_softmax(x, dim=1)


model = Net()
if use_cuda:
    model.cuda()
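The flattened sizes hard-coded in `Net` (`10 * 3 * 3` for the localization regressor and `320` for `fc1`) follow from the usual output-size arithmetic for unpadded, stride-1 convolutions and 2 x 2 max pooling. A small plain-Python check, assuming 28 x 28 MNIST inputs, could look like this. `conv_out` and `pool_out` are hypothetical helper names.

```python
def conv_out(size, kernel_size):
    """Output side length of an unpadded, stride-1 convolution."""
    return size - kernel_size + 1


def pool_out(size, kernel_size=2, stride=2):
    """Output side length of max pooling without padding (floor division)."""
    return (size - kernel_size) // stride + 1


# Localization network: Conv(k=7) -> MaxPool(2) -> Conv(k=5) -> MaxPool(2)
side = pool_out(conv_out(pool_out(conv_out(28, 7)), 5))
localization_features = 10 * side * side        # 10 output channels

# Classifier: Conv(k=5) -> MaxPool(2) -> Conv(k=5) -> MaxPool(2)
side_cls = pool_out(conv_out(pool_out(conv_out(28, 5)), 5))
classifier_features = 20 * side_cls * side_cls  # 20 output channels

print(side, localization_features)
print(side_cls, classifier_features)
```

If the input resolution or any kernel size changes, these two numbers change with it, and the `view` calls and `nn.Linear` input sizes in `Net` have to be updated to match.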

######################################################################
# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network
# learns the classification task in a supervised way. At the same time,
# the model learns the STN automatically in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    # Put the model in training mode (enables dropout)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()

        # Variable wrapping was required for autograd in old PyTorch;
        # since PyTorch 0.4 it simply returns the tensor
        data, target = Variable(data), Variable(target)

        # A gradient update takes three steps:
        #   1. reset the gradients:    optimizer.zero_grad()
        #   2. backpropagate the loss: loss.backward()
        #   3. update the parameters:  optimizer.step()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
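The three steps in `train` (clear the gradients, backpropagate, update the parameters) can be illustrated without PyTorch on a one-parameter problem. The sketch below is a toy, assuming a squared-error loss (w - target)^2 whose gradient is written out by hand. `sgd_minimize` is a hypothetical helper, and the learning rate of 0.1 is chosen for the toy rather than taken from the tutorial.

```python
def sgd_minimize(w=0.0, target=3.0, lr=0.1, steps=50):
    """Minimize (w - target)**2 with plain gradient descent."""
    for _ in range(steps):
        grad = 0.0                  # 1. reset the gradient   (optimizer.zero_grad())
        grad += 2.0 * (w - target)  # 2. compute d(loss)/dw   (loss.backward())
        w -= lr * grad              # 3. update the parameter (optimizer.step())
    return w


print(sgd_minimize())
```

Step 1 matters because PyTorch accumulates gradients into `.grad` across `backward()` calls; without the reset, each update would use the sum of all gradients seen so far.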
#
# A simple test procedure to measure the STN's performance on MNIST.
#


def test():
    # Switch to evaluation mode: dropout is disabled and batch
    # normalization (BN) layers, if any, use their running statistics
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability (the predicted class)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, len(test_loader.dataset),
                  100. * correct / len(test_loader.dataset)))
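# Note on averaging the loss:
#
# A pass over the data is split into mini-batches. For each mini-batch,
# F.nll_loss(output, target) does not return the total loss of the
# mini-batch but its average, because the reduction defaults to the mean
# (size_average=True in older PyTorch releases). One way to get the average
# loss per data point is to sum these mini-batch averages over the pass
# and divide by the number of mini-batches.
#
# That approach hides a bug: it assumes every mini-batch has the same
# size, but the last mini-batch is rarely exactly full.
#
# A more accurate approach is to sum the loss within each mini-batch
# instead of averaging it, and divide by the total number of data points
# at the end (reduction='sum'; size_average=False in older releases).
# Roughly:
#
#     for each mini-batch:
#         ...
#         test_loss += F.nll_loss(output, target, reduction='sum').item()
#         ...
#     ...
#     test_loss /= len(test_loader.dataset)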
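To make the averaging pitfall concrete, here is a small plain-Python sketch with made-up per-sample losses (hypothetical numbers, not results from this model) split into mini-batches of unequal size. Averaging the mini-batch averages gives a different answer from the true per-sample average, whereas summing and dividing by the number of samples does not.

```python
# Hypothetical per-sample losses, split into batches of size 4, 4, and 1
batches = [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [10.0]]

num_samples = sum(len(batch) for batch in batches)

# Approach 1: average the per-batch averages (biased when batch sizes differ)
mean_of_means = sum(sum(batch) / len(batch) for batch in batches) / len(batches)

# Approach 2: sum every per-sample loss, then divide by the sample count
true_mean = sum(sum(batch) for batch in batches) / num_samples

print(mean_of_means, true_mean)
```

The single-sample batch gets the same weight as each four-sample batch under the first approach, which is why its large loss dominates the result.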
tot_time = 0  # total elapsed time in seconds

######################################################################
# Visualizing the STN results
# ---------------------------
#
# Now, we will inspect the results of our learned visual attention
# mechanism.
#
# We define a small helper function in order to visualize the
# transformations while training.


def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the MNIST normalization applied by the data loaders
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformer layer
# after training. We visualize a batch of input images and
# the corresponding transformed batch using STN.


def visualize_stn():
    # Get a batch of test data
    data, _ = next(iter(test_loader))
    data = Variable(data)  # modified: the old volatile=True argument is dropped

    if use_cuda:
        data = data.cuda()

    input_tensor = data.cpu().data
    transformed_input_tensor = model.stn(data).cpu().data

    in_grid = convert_image_np(
        torchvision.utils.make_grid(input_tensor))

    out_grid = convert_image_np(
        torchvision.utils.make_grid(transformed_input_tensor))

    # Plot the results side-by-side
    f, axarr = plt.subplots(1, 2)
    axarr[0].imshow(in_grid)
    axarr[0].set_title('Dataset Images')

    axarr[1].imshow(out_grid)
    axarr[1].set_title('Transformed Images')


import time

# Train for 4 epochs to keep the run short; use range(1, 20 + 1) for a longer run
start_time = time.time()
for epoch in range(1, 4 + 1):
    train(epoch)
    test()
tot_time = time.time() - start_time

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
print("Total time= {:.3f}s".format(tot_time))

Published by: 全栈程序员-用户IM. When reposting, please credit the source: https://javaforall.cn/180403.html (original link: https://javaforall.cn)
