大家好,又见面了,我是你们的朋友全栈君。
与大家分享一下自己在学习使用libtorch搭建神经网络时学到的一些心得和例子,记录下来供大家参考。
首先我们要参考着pytorch版的resnet来搭建,这样我们可以省去不必要的麻烦,上代码:
1、首先是pytorch版残差模块
class ResidualBlock(nn.Module):
    """Two-convolution residual block.

    Main path: 3x3 conv (stride `stride`) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: the input itself, or `shortcut(input)` when a shortcut is given.
    The two paths are summed and passed through a final ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels of both convolutions.
        stride: stride of the first convolution (int or tuple).
        shortcut: optional module applied to the input on the skip path. It
            must produce a tensor of the same shape as the main path; when it
            is None the raw input is added, so the main path must then keep
            the input shape.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1),
            nn.BatchNorm2d(outchannel),
        )
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        # In-place add only modifies `out`, the tensor freshly produced by
        # self.left; the caller's input `x` is left untouched.
        out += residual
        return F.relu(out)
2、libtorch版残差模块
因为是用c++搭建的,所以先创建头文件
2.1残差模块头文件(声明)
// Helper that builds a Conv2dOptions object from positional arguments, so each
// convolution in this file can be configured in a single call.
//   in_planes, out_planes : input / output channel counts
//   kerner_size           : square kernel size
//   stride, padding       : applied to both spatial dimensions
//   groups                : number of channel groups
//   with_bias             : whether the convolution has a learnable bias
inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kerner_size,
    int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_bias = true) {
    // Each setter returns a reference to the options object, so they chain;
    // the result is copied out before the temporary is destroyed.
    return torch::nn::Conv2dOptions(in_planes, out_planes, kerner_size)
        .stride(stride)
        .padding(padding)
        .bias(with_bias)
        .groups(groups);
}
// Residual block used by ResNet_ocr.
// With is_basic == true it is the two-convolution "basic" block (3x3 -> 3x3);
// otherwise it is the three-convolution "bottleneck" block (1x1 -> 3x3 -> 1x1)
// whose output has planes * 4 channels.
class Block_ocrImpl : public torch::nn::Module {
public:
    // inplanes    : number of input channels
    // planes      : base number of output channels (x4 for the bottleneck layout)
    // stride_     : stride of the block's 3x3 convolution
    // downsample_ : projection applied to the skip path. An empty Sequential
    //               (the default) means the input is added unchanged. The
    //               default used to be nullptr, which the constructor and
    //               forward() dereference; an empty Sequential is safe.
    // groups, base_width : set the internal width,
    //               int(planes * (base_width / 64.)) * groups
    // is_basic    : basic (true) or bottleneck (false) layout
    Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = torch::nn::Sequential(), int groups = 1, int base_width = 64, bool is_basic = true);
    torch::Tensor forward(torch::Tensor x);
    // Skip-path projection; empty when no projection is needed.
    torch::nn::Sequential downsample{ nullptr };
private:
    bool is_basic = true;
    int64_t stride = 1;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::BatchNorm2d bn2{ nullptr };
    // conv3 / bn3 are only created for the bottleneck layout.
    torch::nn::Conv2d conv3{ nullptr };
    torch::nn::BatchNorm2d bn3{ nullptr };
};
TORCH_MODULE(Block_ocr);
2.2残差模块定义
卷积参数的辅助函数conv_options(用来快速生成Conv2dOptions,省去以后重复书写卷积参数的工作)已经写在2.1的头文件里了;下面是残差模块在源文件(.cpp)中的定义:
// Residual block constructor: creates the convolution / batch-norm layers for
// either the basic or the bottleneck layout and registers them as submodules.
Block_ocrImpl::Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    is_basic = _is_basic;
    // Internal channel width (ResNeXt / wide-ResNet style scaling).
    const int64_t width = static_cast<int64_t>(planes * (base_width / 64.)) * groups;
    // Build only the layers the chosen layout uses. The bottleneck layers used
    // to be created as 3x3 basic convs first and then overwritten, which
    // allocated and randomly initialised weights that were thrown away.
    if (is_basic) {
        // Basic: 3x3 (strided) -> 3x3.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    }
    else {
        // Bottleneck: 1x1 reduce -> 3x3 (strided, grouped) -> 1x1 expand to planes * 4.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }
    // `downsample` (holder) is false for a null holder; `downsample->is_empty()`
    // is true for a Sequential with no layers. Both mean "no projection".
    if (downsample && !downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}
// Forward pass of the residual block: main path (conv/bn/relu stack), plus the
// skip path (the input, or the downsample projection of it), then a final ReLU.
torch::Tensor Block_ocrImpl::forward(torch::Tensor x) {
    // Keep a handle to the block input for the skip connection. A clone is not
    // needed: every main-path op below is out-of-place and only rebinds `x`,
    // and the in-place add further down modifies the batch-norm output, never
    // the input tensor.
    torch::Tensor residual = x;
    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);
    x = conv2->forward(x);
    x = bn2->forward(x);
    if (!is_basic) {
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }
    // Guard against a null holder as well as an empty Sequential.
    if (downsample && !downsample->is_empty()) {
        residual = downsample->forward(residual);
    }
    x += residual;
    x = torch::relu(x);
    return x;
}
3、pytorch版resnet主函数
class ResNet18(nn.Module):
    """ResNet-style convolutional feature extractor (no classifier head).

    Args:
        nc: number of input image channels.

    The stem is a 7x7/stride-2 convolution followed by a 3x3/stride-2 max pool.
    layer2-layer4 use stride (2, 1): they halve the height but keep the width.
    For an input of shape [2, 3, 32, 280] the output is [2, 512, 1, 70].
    """

    def __init__(self, nc):
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 convolution and 3x3 stride-2 max pooling.
        self.pre = nn.Sequential(
            nn.Conv2d(nc, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Body: stacks of 3x3 residual blocks. The third argument is the total
        # number of blocks in each stage.
        self.layer1 = self._make_layer(64, 128, 1)
        self.layer2 = self._make_layer(128, 256, 2, stride=(2, 1))
        self.layer3 = self._make_layer(256, 512, 5, stride=(2, 1))
        self.layer4 = self._make_layer(512, 512, 3, stride=(2, 1))

    def _make_layer(self, inchannel, outchannel, block_num, stride=(1, 1)):
        """Build one stage containing exactly `block_num` residual blocks.

        The first block changes the channel count from `inchannel` to
        `outchannel` and applies `stride`, with a 1x1 conv + BN projection on
        its skip path; the remaining blocks keep the shape unchanged.
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride),
            nn.BatchNorm2d(outchannel),
        )
        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        # block_num - 1 further blocks, so the stage has block_num blocks in
        # total (range(1, block_num + 1) would add one block too many).
        for _ in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Shapes below are for an input of [2, 3, 32, 280] with nc == 3.
        x = self.pre(x)     # [2, 64, 8, 70]
        x = self.layer1(x)  # [2, 128, 8, 70]
        x = self.layer2(x)  # [2, 256, 4, 70]
        x = self.layer3(x)  # [2, 512, 2, 70]
        x = self.layer4(x)  # [2, 512, 1, 70]
        return x
4、libtorch版主函数
和残差模块一样,分为头文件(.h)和源文件(.cpp)
先写头文件,还是仿照pytorch版的来写,这样我们可以避免很多麻烦
4.1主函数头文件(声明)
// ResNet backbone for OCR, without the fully-connected / softmax head.
// The constructor chooses basic blocks for "resnet18" / "resnet34" and
// bottleneck blocks for any other model_type.
class ResNet_ocrImpl : public torch::nn::Module {
public:
    ResNet_ocrImpl(/*std::vector<int> layers, int num_classes = 1000,*/ std::string model_type = "resnet18",
        int groups = 1, int width_per_group = 64);

    // Stem followed by the four residual stages; returns the layer4 feature map.
    torch::Tensor forward(torch::Tensor x);

    // NOTE(review): declared here, but no definition is visible in this file;
    // confirm one exists before calling it (otherwise it fails to link).
    std::vector<torch::Tensor> features(torch::Tensor x);

    // Builds one stage of `blocks` residual blocks and advances `inplanes`.
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);

private:
    int expansion = 1;      // output-channel multiplier of a block (4 for bottleneck)
    bool is_basic = true;   // basic (true) or bottleneck (false) residual blocks
    int64_t inplanes = 64;  // running input channel count used by _make_layer
    int groups = 1;         // forwarded to every residual block
    int base_width = 64;    // forwarded to every residual block

    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Sequential layer1{ nullptr };
    torch::nn::Sequential layer2{ nullptr };
    torch::nn::Sequential layer3{ nullptr };
    torch::nn::Sequential layer4{ nullptr };
};
TORCH_MODULE(ResNet_ocr);
4.2主函数定义
// Build one residual stage made of `blocks` Block_ocr modules with `planes`
// base channels. Only the first block applies `stride`; when the stride or the
// channel count changes, its skip path gets a 1x1 conv + batch-norm projection.
// Side effect: advances `inplanes` to planes * expansion for the next stage.
torch::nn::Sequential ResNet_ocrImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    const int64_t out_planes = planes * expansion;

    // Skip-path projection; stays an empty Sequential when shapes already match.
    torch::nn::Sequential projection;
    const bool needs_projection = (stride != 1) || (inplanes != out_planes);
    if (needs_projection) {
        projection = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, out_planes, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(out_planes));
    }

    torch::nn::Sequential stage;
    stage->push_back(Block_ocr(inplanes, planes, stride, projection, groups, base_width, is_basic));
    inplanes = out_planes;
    for (int64_t idx = 1; idx < blocks; ++idx) {
        stage->push_back(Block_ocr(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }
    return stage;
}
// ResNet backbone constructor (no classifier head).
//   model_type       : "resnet18" / "resnet34" use basic blocks; every other
//                      value uses bottleneck blocks (expansion 4). Blocks per
//                      stage follow the standard ResNet depths for "resnet18",
//                      "resnet34", "resnet50", "resnet101" and "resnet152";
//                      any other name keeps {2, 2, 2, 2}.
//   _groups          : group count forwarded to every residual block.
//   _width_per_group : base width forwarded to every residual block.
// The stem convolution takes a single input channel.
ResNet_ocrImpl::ResNet_ocrImpl(/*std::vector<int> layers, int num_classes,*/ std::string model_type, int _groups, int _width_per_group)
{
    if (model_type != "resnet18" && model_type != "resnet34")
    {
        expansion = 4;
        is_basic = false;
    }
    groups = _groups;
    base_width = _width_per_group;

    // Blocks per stage (layer1..layer4). Without this table every model_type
    // got {2, 2, 2, 2}, so e.g. "resnet34" built exactly the resnet18 network.
    std::vector<int64_t> blocks{ 2, 2, 2, 2 };
    if (model_type == "resnet34" || model_type == "resnet50") {
        blocks = { 3, 4, 6, 3 };
    }
    else if (model_type == "resnet101") {
        blocks = { 3, 4, 23, 3 };
    }
    else if (model_type == "resnet152") {
        blocks = { 3, 8, 36, 3 };
    }

    conv1 = torch::nn::Conv2d(conv_options(1, 64, 7, 2, 3, 1, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    // Stages must be built in this order: _make_layer advances `inplanes`.
    layer1 = torch::nn::Sequential(_make_layer(64, blocks[0]));
    layer2 = torch::nn::Sequential(_make_layer(128, blocks[1], 2));
    layer3 = torch::nn::Sequential(_make_layer(256, blocks[2], 2));
    layer4 = torch::nn::Sequential(_make_layer(512, blocks[3], 2));
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);

    // Weight initialisation: Kaiming-normal (fan_out, ReLU) for every
    // convolution; weight = 1 and bias = 0 for every 2-D batch norm.
    for (auto& module : modules(/*include_self=*/false)) {
        if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
            torch::nn::init::kaiming_normal_(
                M->weight,
                /*a=*/0,
                torch::kFanOut,
                torch::kReLU);
        }
        else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
            torch::nn::init::constant_(M->weight, 1);
            torch::nn::init::constant_(M->bias, 0);
        }
    }
}
// Forward pass of the backbone: stem (conv -> bn -> relu -> 3x3/2 max-pool)
// followed by layer1..layer4 in order. No classifier head is applied, so the
// result is the layer4 feature map.
torch::Tensor ResNet_ocrImpl::forward(torch::Tensor x) {
    torch::Tensor out = torch::relu(bn1->forward(conv1->forward(x)));
    out = torch::max_pool2d(out, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1);
    for (torch::nn::Sequential* stage : { &layer1, &layer2, &layer3, &layer4 }) {
        out = (*stage)->forward(out);
    }
    return out;
}
以上就是libtorch版的resnet18网络,完全使用C++搭建。由于我需要把resnet和别的网络拼接,所以删掉了fc层和softmax层,有需要的可以自己加上。这里也参考了一位GitHub大神的写法。
科技无罪、知识无罪,我们要做知识的传播者!
发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/142669.html原文链接:https://javaforall.cn
【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛
【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...