ResNet34+Unet(可以直接用)

ResNet34+Unet(可以直接用)importtorchfromtorchimportnnimporttorch.nn.functionalasF#因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码classResidualBlock(nn.Module):def__init__(self,inchannel,outchannel,stride,shortcut=None):super(ResidualBlock,self).__init__(

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

import torch
from torch import nn
import torch.nn.functional as F
# 因为ResNet34包含重复的单元,故用ResidualBlock类来简化代码
class ResidualBlock(nn.Module):
def __init__(self, inchannel, outchannel, stride, shortcut=None):
super(ResidualBlock, self).__init__()
self.basic = nn.Sequential(
nn.Conv2d(inchannel, outchannel, 3, stride, 1,
bias=False),  # 要采样的话在这里改变stride
nn.BatchNorm2d(outchannel),  # 批处理正则化
nn.ReLU(inplace=True),  # 激活
nn.Conv2d(outchannel, outchannel, 3, 1, 1,
bias=False),  # 采样之后注意保持feature map的大小不变
nn.BatchNorm2d(outchannel),
)
self.shortcut = shortcut
def forward(self, x):
out = self.basic(x)
residual = x if self.shortcut is None else self.shortcut(x)  # 计算残差
out += residual
return nn.ReLU(inplace=True)(out)  # 注意激活
class Conv2dReLU(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
padding=0,
stride=1,
):
super(Conv2dReLU, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.bn(self.conv(x)))
return x
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
class SegmentationHead(nn.Sequential):
def __init__(self,
in_channels=16,
out_channels=1,
kernel_size=3,
upsampling=1):
conv2d = nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
padding=kernel_size // 2)
upsampling = nn.UpsamplingBilinear2d(
scale_factor=upsampling) if upsampling > 1 else nn.Identity()
super().__init__(conv2d, upsampling)
# ResNet类
class Resnet34(nn.Module):
def __init__(self, inchannels):
super(Resnet34, self).__init__()
self.pre = nn.Sequential(
nn.Conv2d(inchannels, 64, 7, 2, 3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 1),
)  # 开始的部分
self.body = self.makelayers([3, 4, 6, 3])  # 具有重复模块的部分
in_channels = [512, 256, 128, 128, 32]
skip_channels = [256, 128, 64, 0, 0]
out_channels = [256, 128, 64, 32, 16]
blocks = [
DecoderBlock(in_ch, skip_ch,
out_ch) for in_ch, skip_ch, out_ch in zip(
in_channels, skip_channels, out_channels)
]
self.blocks = nn.ModuleList(blocks)
self.seg = SegmentationHead()
def makelayers(self, blocklist):  # 注意传入列表而不是解列表
self.layers = []
for index, blocknum in enumerate(blocklist):
if index != 0:
shortcut = nn.Sequential(
nn.Conv2d(64 * 2**(index - 1),
64 * 2**index,
1,
2,
bias=False),
nn.BatchNorm2d(64 * 2**index))  # 使得输入输出通道数调整为一致
self.layers.append(
ResidualBlock(64 * 2**(index - 1), 64 * 2**index, 2,
shortcut))  # 每次变化通道数时进行下采样
for i in range(0 if index == 0 else 1, blocknum):
self.layers.append(
ResidualBlock(64 * 2**index, 64 * 2**index, 1))
return nn.Sequential(*self.layers)
def forward(self, x):
self.features = []
# 下采样
# x = self.pre(x)
for i, l in enumerate(self.pre):
x = l(x)
if i == 2:
self.features.append(x)
print("y=", len(self.features))
for i, l in enumerate(self.body):
if i == 3 or i == 7 or i == 13:
self.features.append(x)
x = l(x)
skips = self.features[::-1]
# skips = self.features[1:]
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
x = self.seg(x)
return x



四次Skipconnect分别在:Maxpool前;另外三次在通道数变化前。
上采样combine时采用的是插值(nn.functionnal.interpolate)。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/185682.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • python flask教程_python框架有哪些

    python flask教程_python框架有哪些大家好,这算是我使用CSDN以来第一次正二八经的想自己写一篇博客。如果有写的不好的地方还请大家见谅!使用pipenv的方便之处就是可以单独的为每一个python 项目建立对应的虚拟环境,而且该过程简单方便。下面我会用简短的步骤来描述这个过程:1. 首先使用pip进行安装pipenv。 用管理员身份打开命令行(cmd),然后输入pipinstallpipenv 回车,结果如下图所…

  • mybatishelperpro激活码_在线激活

    (mybatishelperpro激活码)最近有小伙伴私信我,问我这边有没有免费的intellijIdea的激活码,然后我将全栈君台教程分享给他了。激活成功之后他一直表示感谢,哈哈~IntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,下面是详细链接哦~https://javaforall.cn/100143.html2K…

  • java 输出_java怎么输出

    java 输出_java怎么输出展开全部java控制台输出由print()和println()来完成最为简单。这两种方法由rintStream(System.out引用32313133353236313431303231363533e78988e69d8331333365643661的对象类型)定义。尽管System.out是一个字节流,用它作为简单程序的输出是可行的。因为PrintStream是从OutputStrea…

  • UNITY ET 框架

    UNITY ET 框架GITHUB上近3000星的开源框架,包括了服务器客户端,ILRUNTIME热等特点,对于新项目,值得拥有

  • linux rpm 卸载 java_linux下用rpm 安装卸载jdk「建议收藏」

    linux rpm 卸载 java_linux下用rpm 安装卸载jdk「建议收藏」1、如果linux是centos的话,请先卸载openjdkjava-version,会有下面的信息:卸载默认的用root用户登陆到系统,打开一个终端输入#rpm-qa|grepgcj显示内容其中包含下面两行信息#java-1.4.2-gcj-compat-1.4.2.0-27jpp#java-1.4.2-gcj-compat-devel-l.4.2.0-27jpp卸载#rpm-…

  • weblogic环境,应用上传图片报Could not initialize class sun.awt.X11.XToolkit

    weblogic环境,应用上传图片报Could not initialize class sun.awt.X11.XToolkit

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号