libtorch-resnet18



I'd like to share some notes and examples from learning to build neural networks with libtorch, written down here for reference.
We build it by mirroring the PyTorch version of ResNet, which spares us a lot of unnecessary trouble. On to the code:
1. The PyTorch residual block

import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        # Main branch: two 3x3 convolutions with batch norm
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1),
            nn.BatchNorm2d(outchannel)
        )
        # Shortcut branch: identity, or a projection passed in by the caller
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

2. The libtorch residual block
Since this one is built in C++, we start by creating a header file.
2.1 Residual block header (declaration)

#pragma once
#include <torch/torch.h>

// Overloaded convolution helper: bundles the Conv2dOptions chain into one call
inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kernel_size,
	int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_bias = true) {
	torch::nn::Conv2dOptions conv_options = torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size);
	conv_options.stride(stride);
	conv_options.padding(padding);
	conv_options.bias(with_bias);
	conv_options.groups(groups);
	return conv_options;
}

// Residual block declaration
class Block_ocrImpl : public torch::nn::Module {
public:
	// default downsample is an empty Sequential (a nullptr holder would break the is_empty() checks in the definition)
	Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
		torch::nn::Sequential downsample_ = torch::nn::Sequential(), int groups = 1, int base_width = 64, bool is_basic = true);
	torch::Tensor forward(torch::Tensor x);
	torch::nn::Sequential downsample{ nullptr };
private:
	bool is_basic = true;
	int64_t stride = 1;
	torch::nn::Conv2d conv1{ nullptr };
	torch::nn::BatchNorm2d bn1{ nullptr };
	torch::nn::Conv2d conv2{ nullptr };
	torch::nn::BatchNorm2d bn2{ nullptr };
	torch::nn::Conv2d conv3{ nullptr };
	torch::nn::BatchNorm2d bn3{ nullptr };
};
TORCH_MODULE(Block_ocr);
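A note on the two pieces above. TORCH_MODULE(Block_ocr) generates a holder class named Block_ocr that wraps std::shared_ptr<Block_ocrImpl>, which is why later code constructs Block_ocr values and calls methods through ->. And conv_options simply bundles the option chain; as a minimal sketch of my own (not from the original post), these two layers are configured identically:

// A 3x3, stride-2, bias-free convolution built through the helper...
torch::nn::Conv2d a(conv_options(64, 128, 3, 2, 1, 1, false));
// ...and the same layer with the options chained by hand
torch::nn::Conv2d b(torch::nn::Conv2dOptions(64, 128, 3)
	.stride(2).padding(1).groups(1).bias(false));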

2.2 Residual block definition
We also need an overloaded convolution helper so we don't keep writing the same options over and over; that's the conv_options function, which I put in the header file shown in 2.1.

// Residual block definition
Block_ocrImpl::Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    int width = int(planes * (base_width / 64.)) * groups;

    // BasicBlock layout: two 3x3 convolutions
    conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    is_basic = _is_basic;
    // Bottleneck layout: 1x1 -> 3x3 -> 1x1, expanding the channels by 4
    if (!is_basic) {
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    // Only register the projection shortcut when it actually contains layers
    if (!downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}

// Residual block forward pass
torch::Tensor Block_ocrImpl::forward(torch::Tensor x) {
    torch::Tensor residual = x.clone();

    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);

    x = conv2->forward(x);
    x = bn2->forward(x);

    if (!is_basic) {
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }

    // Project the residual when the shortcut changes shape
    if (!downsample->is_empty()) {
        residual = downsample->forward(residual);
    }

    x += residual;
    x = torch::relu(x);

    return x;
}
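It's worth smoke-testing the block on its own before wiring up the full network. A minimal sketch of my own (the sizes are made up): a stride-2 basic block from 64 to 128 channels needs a matching 1x1 projection shortcut, just like the shortcut argument in the PyTorch version:

// Projection shortcut: 1x1 stride-2 conv + batch norm, matching the main branch
torch::nn::Sequential shortcut(
    torch::nn::Conv2d(conv_options(64, 128, 1, 2, 0, 1, false)),
    torch::nn::BatchNorm2d(128));
Block_ocr block(64, 128, 2, shortcut);
torch::Tensor y = block->forward(torch::randn({2, 64, 8, 70}));
std::cout << y.sizes() << std::endl;  // should print [2, 128, 4, 35]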

3. The PyTorch ResNet main network

class ResNet18(nn.Module):
    def __init__(self, nc):  # nc: number of input channels
        super(ResNet18, self).__init__()
        ### Input stem: a 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool
        self.pre = nn.Sequential(
            nn.Conv2d(nc, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        ### Middle stages: stacked 3x3 convolutions extract features; the count below
        ### controls how many times the block is repeated in each stage
        self.layer1 = self._make_layer(64, 128, 1)

        self.layer2 = self._make_layer(128, 256, 2, stride=(2, 1))

        self.layer3 = self._make_layer(256, 512, 5, stride=(2, 1))

        self.layer4 = self._make_layer(512, 512, 3, stride=(2, 1))


    def _make_layer(self, inchannel, outchannel, block_num, stride=(1, 1)):
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride),
            nn.BatchNorm2d(outchannel)
        )
        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))  # changes the channel count
        for i in range(1, block_num + 1):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    ### Define the data flow through the network
    def forward(self, x):
        x = self.pre(x)     ### [2, 3, 32, 280] ---> [2, 64, 8, 70]
        x = self.layer1(x)  ### [2, 128, 8, 70]
        x = self.layer2(x)  ### [2, 256, 4, 70]  (stride (2, 1): height halves, width is kept)
        x = self.layer3(x)  ### [2, 512, 2, 70]
        x = self.layer4(x)  ### [2, 512, 1, 70]
        return x
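One porting detail to watch: the PyTorch version downsamples with stride=(2, 1), halving the height while keeping the full width (useful for OCR), but the conv_options helper from 2.1 only takes a scalar stride, and the libtorch network below strides by a plain 2. If you need the (2, 1) behaviour in libtorch, Conv2dOptions also accepts a per-dimension stride; a sketch of my own, not part of the original code:

// Asymmetric stride, equivalent to stride=(2, 1) in the PyTorch version:
// halves the height, keeps the width
torch::nn::Conv2d conv(torch::nn::Conv2dOptions(128, 256, 3)
	.stride({2, 1}).padding(1).bias(false));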

4. The libtorch main network
As with the residual block, it is split into a header (.h) and a source file (.cpp).
We write the header first, again mirroring the PyTorch version to avoid a lot of trouble.
4.1 Main network header (declaration)

// Main network declaration
class ResNet_ocrImpl : public torch::nn::Module {
public:
    ResNet_ocrImpl(/*std::vector<int> layers, int num_classes = 1000,*/ std::string model_type = "resnet18",
        int groups = 1, int width_per_group = 64);
    torch::Tensor forward(torch::Tensor x);
    std::vector<torch::Tensor> features(torch::Tensor x);
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);
private:
    int expansion = 1; bool is_basic = true;
    int64_t inplanes = 64; int groups = 1; int base_width = 64;
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Sequential layer1{ nullptr };
    torch::nn::Sequential layer2{ nullptr };
    torch::nn::Sequential layer3{ nullptr };
    torch::nn::Sequential layer4{ nullptr };
};
TORCH_MODULE(ResNet_ocr);

4.2 Main network definition

// First define _make_layer, again modeled on the PyTorch version
torch::nn::Sequential ResNet_ocrImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    torch::nn::Sequential downsample;
    // A projection shortcut is needed whenever the spatial size or channel count changes
    if (stride != 1 || inplanes != planes * expansion) {
        downsample = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(planes * expansion)
        );
    }
    torch::nn::Sequential layers;
    layers->push_back(Block_ocr(inplanes, planes, stride, downsample, groups, base_width, is_basic));
    inplanes = planes * expansion;  // subsequent blocks see the expanded channel count
    for (int64_t i = 1; i < blocks; i++) {
        layers->push_back(Block_ocr(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }

    return layers;
}
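To make the downsample rule concrete: when the constructor below calls _make_layer(128, 2, 2) while inplanes is still 64 and expansion is 1 (resnet18), the condition fires, and the projection built above unrolls to this (my own expansion of what the code does):

// First block of layer2: inplanes=64, planes=128, stride=2, expansion=1
torch::nn::Sequential downsample(
    torch::nn::Conv2d(conv_options(64, 128, 1, 2, 0, 1, false)),
    torch::nn::BatchNorm2d(128));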
// Then define the main network constructor
ResNet_ocrImpl::ResNet_ocrImpl(/*std::vector<int> layers, int num_classes,*/ std::string model_type, int _groups, int _width_per_group)
{
    // resnet50 and deeper use bottleneck blocks with a 4x channel expansion
    if (model_type != "resnet18" && model_type != "resnet34")
    {
        expansion = 4;
        is_basic = false;
    }
    groups = _groups;
    base_width = _width_per_group;
    conv1 = torch::nn::Conv2d(conv_options(1, 64, 7, 2, 3, 1, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    layer1 = torch::nn::Sequential(_make_layer(64, 2/*layers[0]*/));
    layer2 = torch::nn::Sequential(_make_layer(128, 2/*layers[1]*/, 2));
    layer3 = torch::nn::Sequential(_make_layer(256, 2/*layers[2]*/, 2));
    layer4 = torch::nn::Sequential(_make_layer(512, 2/*layers[3]*/, 2));
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
    // Kaiming initialization for convolutions, constants for batch norm
    for (auto& module : modules(/*include_self=*/false)) {
        if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
            torch::nn::init::kaiming_normal_(
                M->weight,
                /*a=*/0,
                torch::kFanOut,
                torch::kReLU);
        }
        else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
            torch::nn::init::constant_(M->weight, 1);
            torch::nn::init::constant_(M->bias, 0);
        }
    }
}

// Main network forward pass
torch::Tensor ResNet_ocrImpl::forward(torch::Tensor x) {
    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    x = layer1->forward(x);
    x = layer2->forward(x);
    x = layer3->forward(x);
    x = layer4->forward(x);
    return x;
}
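Finally, an end-to-end check, as a minimal sketch of my own. Note that conv1 above takes a single-channel input, and that this C++ version strides by 2 in both dimensions in layers 2-4, so unlike the PyTorch OCR variant the width shrinks as well:

#include <torch/torch.h>
#include <iostream>

int main() {
    // assumes the ResNet_ocr header and source above are part of the build
    ResNet_ocr net("resnet18");
    net->eval();
    torch::NoGradGuard no_grad;
    torch::Tensor x = torch::randn({2, 1, 32, 280});
    torch::Tensor y = net->forward(x);
    std::cout << y.sizes() << std::endl;  // should print [2, 512, 1, 9]
    return 0;
}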

That's the libtorch version of the ResNet18 network, built entirely in C++. Because my ResNet has to be spliced onto other networks, I removed the fc and softmax layers; fill them back in yourself if you need them. The approach here follows the style of a ResNet implementation by a GitHub author.
Technology is blameless and so is knowledge; let's be spreaders of knowledge!
