# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: Ghassen HAMROUNI

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__.

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, an STN can crop a region of interest, scale, and correct
the orientation of an image. This is a useful mechanism because CNNs
are not invariant to rotation, scale, or more general affine
transformations.

One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni
from __future__ import print_function

import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torchvision import datasets, transforms

#from logger import Logger
#logger = Logger('./logs')
plt.ion()   # interactive mode

# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

use_cuda = torch.cuda.is_available()

# Both splits use the same preprocessing: PIL image -> float tensor in
# [0, 1] (ToTensor must come first because Normalize operates on tensors),
# then per-channel standardization with the MNIST mean/std.
# NOTE(review): num_workers > 0 at module top level requires an
# `if __name__ == "__main__":` guard on platforms that spawn workers
# (Windows/macOS); confirm before running there.

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=128, shuffle=True, num_workers=4)

# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)
# Depicting spatial transformer networks
# --------------------------------------
#
# Spatial transformer networks boil down to three main components:
#
# -  The localization network is a regular CNN which regresses the
#    transformation parameters. The transformation is never learned
#    explicitly from this dataset; instead, the network automatically
#    learns the spatial transformations that enhance the global accuracy.
# -  The grid generator generates a grid of coordinates in the input
#    image corresponding to each pixel of the output image.
# -  The sampler uses the sampling grid produced from the transformation
#    parameters and applies it to the input image.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. Note::
#    This requires a version of PyTorch that provides the
#    ``F.affine_grid`` and ``F.grid_sample`` functions.
class Net(nn.Module):
    """MNIST classifier with a spatial transformer in front of a small CNN.

    Each input batch is first warped by :meth:`stn`, which applies an affine
    transform regressed by a localization network. The warped batch is then
    classified by a LeNet-style network. Tensors use PyTorch's NCHW layout.
    The hard-coded flatten sizes (``10 * 3 * 3`` and ``320``) assume
    ``N x 1 x 28 x 28`` inputs.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network:
        # 1x28x28 -> conv7 -> 8x22x22 -> pool -> 8x11x11
        #         -> conv5 -> 10x7x7  -> pool -> 10x3x3
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with the identity transformation.
        # With zero weights, the last layer (fc_loc[2]) outputs its bias for
        # any input, so theta == [[1, 0, 0], [0, 1, 0]] and the STN starts
        # out as a no-op.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Spatial transformer forward function.

        Predicts one affine matrix per sample from ``x`` itself and resamples
        ``x`` with it.

        Args:
            x: input batch, ``N x 1 x 28 x 28``.

        Returns:
            The transformed batch, with the same size as ``x``.
        """
        xs = self.localization(x)          # N x 10 x 3 x 3
        xs = xs.view(-1, 10 * 3 * 3)       # N x 90
        theta = self.fc_loc(xs)            # N x 6
        theta = theta.view(-1, 2, 3)       # N x 2 x 3 affine matrices

        # affine_grid: theta (N x 2 x 3) plus the target output size
        # (N x C x H x W) -> sampling grid of shape N x H x W x 2.
        grid = F.affine_grid(theta, x.size())
        # grid_sample: input N x C x IH x IW, grid N x OH x OW x 2,
        # padding_mode defaults to 'zeros' -> output N x C x OH x OW.
        x = F.grid_sample(x, grid)
        return x                           # N x 1 x 28 x 28

    def forward(self, x):
        """Return per-class log-probabilities (``N x 10``) for ``x`` (``N x 1 x 28 x 28``)."""
        # transform the input
        x = self.stn(x)                                              # N x 1 x 28 x 28

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # N x 10 x 12 x 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # N x 20 x 4 x 4
        x = x.view(-1, 320)                                          # N x 320
        x = F.relu(self.fc1(x))                                      # N x 50
        # F.dropout defaults to training=True, so tie it to the module mode
        # to make eval() actually disable it.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)                                              # N x 10
        return F.log_softmax(x, dim=1)
model = Net()
if use_cuda:
    # Move the parameters to the GPU before handing them to the optimizer.
    model.cuda()

# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network is
# learning the classification task in a supervised way. At the same time,
# the model is learning the STN automatically in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)
def train(epoch, log_interval=100):
    """Run one training epoch of ``model`` over ``train_loader`` with ``optimizer``.

    Args:
        epoch: epoch number; used only in the progress messages.
        log_interval: print the current loss every ``log_interval`` batches.
    """
    # test() switches the model to eval mode; switch back so dropout is
    # active again during training.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
# A simple test procedure to measure the STN's performance on MNIST.
def test():
    """Evaluate ``model`` on ``test_loader`` and print the average loss and accuracy."""
    with torch.no_grad():
        # Eval mode disables dropout layers (e.g. conv2_drop) so that the
        # evaluation is deterministic.
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            # Sum up the batch loss rather than averaging it, so the final
            # division by the dataset size stays exact when the last batch
            # is smaller than the others.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Get the index of the max log-probability (the predicted class).
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))
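# Note on averaging the test loss:
#
# One pass over the data is split into several mini-batches.
# For each mini-batch, the value of F.nll_loss(output, target).data[0] is not
# the loss of the whole mini-batch but its *average* loss. This is because of
# the default parameter size_average=True (used again below).
# A naive approach sums these per-mini-batch average losses over one pass,
# then divides that sum by the number of mini-batches to obtain the final
# per-data-point average loss.
# That hides a bug: it assumes every mini-batch has the same size, which is
# the only case where averaging the averages is correct. In practice the last
# mini-batch is rarely exactly "full".
# A more accurate way is to sum each mini-batch's loss directly instead of
# averaging it, and finally divide by the total number of data points.
# Roughly:
#
#     for each mini-batch:
#         test_loss += F.nll_loss(output, target, size_average=False).data[0]
#     test_loss /= len(test_loader.dataset)
#
# (In current PyTorch, write ``reduction='sum'`` instead of
# ``size_average=False`` and ``.item()`` instead of ``.data[0]``.)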
# Visualizing the STN results
# ---------------------------
# Now, we will inspect the results of our learned visual attention
# mechanism.
# We define a small helper function in order to visualize the
# transformations while training.
def convert_image_np(inp, mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)):
    """Convert a ``C x H x W`` image tensor to an ``H x W x C`` numpy image.

    Undoes a channel-wise ``transforms.Normalize(mean, std)`` and clips the
    result to ``[0, 1]`` so it can be passed straight to ``plt.imshow``.

    Args:
        inp: image tensor of shape ``C x H x W``. It may live on any device
            and may require grad; it is detached and copied to the CPU first.
        mean: per-channel mean to add back (length ``C``, or 1 to broadcast).
        std: per-channel standard deviation to multiply back (length ``C``,
            or 1 to broadcast).

    Returns:
        ``numpy.ndarray`` of shape ``H x W x C`` with values in ``[0, 1]``.
    """
    # NOTE: the defaults are ImageNet statistics, kept for backward
    # compatibility. The MNIST loaders in this file normalize with
    # mean=0.1307 and std=0.3081, so pass those for an exact inversion.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.asarray(std) * inp + np.asarray(mean)
    return np.clip(inp, 0, 1)
# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.
def visualize_stn():
    """Plot one test batch next to its STN-transformed version, side by side."""
    # No gradients are needed for visualization. This replaces the legacy
    # ``Variable(..., volatile=True)`` idiom.
    with torch.no_grad():
        # Get a batch of test data
        data, _ = next(iter(test_loader))
        if use_cuda:
            data = data.cuda()

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        _, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')
# Train for a few epochs, evaluating on the test set after each one.
# Increase num_epochs (e.g. to 20) for a longer training run.
num_epochs = 4
start_time = time.time()
for epoch in range(1, num_epochs + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

tot_time = time.time() - start_time
print("Total time= {:.3f}s".format(tot_time))

# Leave interactive mode (enabled by plt.ion() above) and block, so the
# figure stays open when the script finishes.
plt.ioff()
plt.show()
