语义分割代码一步步实现_语义分割应用

语义分割代码一步步实现_语义分割应用语义分割的整体实现代码大致思路很简单,但是具体到细节,就有很多可说的东西。之前写过一篇文章,可能有些地方现在又有了新的思路或者感受,或者说之前没有突出重点。作为一个小白,这里把自己知道的知识写一下,事无巨细,希望看到的人能有所收获。一、文件思路总的来说,语义分割代码可以分为如下几个部分:data:图像数据 data/train:训练集数据 data/train/img:…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

语义分割的整体实现代码大致思路很简单,但是具体到细节,就有很多可说的东西。

之前写过一篇文章,可能有些地方现在又有了新的思路或者感受,或者说之前没有突出重点。

作为一个小白,这里把自己知道的知识写一下,事无巨细,希望看到的人能有所收获。

一、文件思路

总的来说,语义分割代码可以分为如下几个部分:

语义分割代码一步步实现_语义分割应用语义分割代码一步步实现_语义分割应用

  • data:图像数据
  • data/train:训练集数据
  • data/train/img:训练集原始图像img
  • data/train/label:训练集原始图像label
  • data/val:验证集数据
  • data/val/img:验证集原始图像img
  • data/val/label:验证集原始图像label
  • dataset:将本地数据转化成pytorch对应的DataSet的文件
  • model:网络模型
  • utils:工具文件
  • utils/args:参数类
  • utils/utils:通用方法类
  • train.py:训练网络代码

当然,这只是一种划分文件的思路,还有很多不错的思路,大家选择一种即可。

二、代码实现思路

代码实现思路其实就是对上面文件的诠释了。

1、图像数据

没有图像数据啥也做不了,所以我们首先要从数据说起。

针对数据来讲,有哪些需要注意的事项呢?

  1. 图像数据是否过大
  2. 图像数据是否需要增强预处理
  3. 图像数据是否需要提前切分为测试集和验证集

1、图像数据过大

当图像数据过大时,很容易造成内存满的问题,导致我们训练失败。

方法:A、采用cv2.resize将图像缩小。B、将图像split为小图像。

2、图像提前预处理

图像提前预处理是为了让图像更好的去训练,如果原始图像存在过于模糊等问题,那么我们就需要做一些预处理操作。

方法:采用各种数据增强库,如:albumentations库,对图像亮度、对比度、锐度等进行增强。

3、图像数据是否提前切分为测试集和验证集

一般来说,我们在代码实现阶段可以将图像进行切分,当然,如果图像数据表示很明显简单,我们完全可以手动将数据分为测试集和验证集,这就免了在代码中实现对图像读取切分等操作了,自愿而为。

2、将本地图像数据集转化为pytorch的DataSet

本地图像数据执行完第一步之后,我们便来到了这一步。

为什么要将本地图像数据集转化为pytorch的DataSet呢?

这是因为我们要使用pytorch中的DataLoader类,DataSet作为DataLoader类的参数,必须满足pytorch的要求。

具体怎么实现呢?很简单,大家可以上网搜一下:如何将数据转化为pytorch的数据集。这里简单说一下。

class DataSet(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()

    def __len__(self):
        return len(img_list)

    def __getitem__(self, idx):
        return img, label

如上所示,我们只需写一个类,继承torch.utils.data.Dataset类,然后重写它的__len__()方法和__getitem__()方法即可。

其中__len__()方法是返回数据集大小,__getitem__()方法是返回对应idx的img和label。

这里又要说一个重点了!!!

  1. 图像数据增强
  2. 图像数据对应矩阵数据格式
  3. img和label的处理
  4. 数据集切分

1、图像数据增强

这里的增强不同于之前的图像数据离线预处理,图像数据预处理是为了让图像变得更好,让模型更容易训练。

而这里的图像在线增强是为了让图像变坏,增大训练难度,比如反转等。

一般使用:

class torchvision.transforms.Compose(transforms)

不过这里也有两个重要的操作,比如:(一般我们是要对img进行如下处理)

class torchvision.transforms.Normalize(mean, std)

给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std。

class torchvision.transforms.ToTensor

把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor

2、图像数据对应矩阵数据格式

为什么说这个问题呢,因为这个对应了你使用什么损失函数。如果你使用了交叉熵损失,你就要将label转化为long形式,如果你使用MSE损失,那么你就要将label转化为float形式,这个可以在报错的时候再改正。

3、img和label的处理

一个重点!!!

img和label的处理主要是其维度的处理,当然,这个东西我也不是太理解具体细节。

但是,我知道的是,维度取决于你采用什么损失函数

如果你采用MSE,那么你就要将label处理成(分类数,label.shape[0], label.shape[1])三维,这样,在计算的时候,label的维度就变成了(batch_size, NUM_CLASSES, label.shape[0], label.shape[1])四维,然后和模型输出output四维进行计算。

如果你采用交叉熵,那么你不用对label维度进行处理,这样label计算时候的维度就是(batch_size, label.shape[0], label.shape[1])三维,因为交叉熵计算的(output, label)固定label比output少一维。

一个重点!!!

label归一化后,处理成mask形式,也就是对每个像素打了标签。

如果是二分类,则将label处理成0、1矩阵,如果三分类,则将label处理成0、1、2矩阵。

4、数据集切分

当你处理完数据后,可以用代码划分数据集,例如:

采用这个方法:from torch.utils.data import random_split
dataset = BagDataset(transform)

train_size = int(0.9 * len(dataset))  # 整个训练集中,百分之90为训练集
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])  # 划分训练集和测试集

3、网络模型

实现了第1步和第2步后,基本上我们的数据处理就差不多了,简述一下我们之前做了什么:

  • 数据预处理、切分
  • 数据归一化、维度变换
  • 数据集切分

然后我们的数据阶段基本就结束了,然后我们就开始写模型了,当然,这部分不做多阐述,因为模型基本上一搜一堆,这不是我们讲本文章的重点。具体怎么写一个网络结构需要大家自己去学习了。

4、args和utils

args主要是一些参数的设置,比如:

import os
import torch

class Args():
    def __init__(self):
        super().__init__()

        # 输入通道数
        self.in_channel = 1
        # 几分类问题
        self.NUM_CLASSES = 2

        # 设置太大会内存溢出
        self.batch_size = 3
        # 线程数
        self.num_workers = 4
        # 学习率
        self.lr = 0.001

        # 数据集根目录
        self.path = 'data'

        # 训练集地址
        self.train_mode = 'train'
        # 验证集地址
        self.val_mode = 'val'

        # 设置多gpu
        self.gpu_ids = [0, 1, 2]
        os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2'
        self.cuda = torch.cuda.is_available()

utils主要是一些方法,具体你需要哪些方法,都可以写进去,比如切分数据集,合并数据集,判断路径是否存在,图像转化等等。

5、train.py

DataSet有了、Model有了,接下来到了最重要的部分,就是讲data—–>model了。

先讲训练trainer的部分。

  1. 设置DataLoader,将数据集传入
  2. 获得模型,设置多GPU并行
  3. 设置优化器
  4. 设置损失标准
  5. 从dataloader中获取数据
  6. 优化器梯度设置0
  7. 将img传入net获得output
  8. 计算output和label的损失
  9. 损失反向传播
  10. 优化器执行下一步
  11. 执行一段时间,保存net模型

1、设置DataLoader,将数据集传入,如:

self.train_dataset = DataSet(path=self.args.path, mode=self.args.train_mode)

self.train_img_loader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size, shuffle=False,
                                num_workers=self.args.num_workers)

2、获得模型,设置多GPU并行,如:

self.net = UNet(self.args.in_channel, self.args.NUM_CLASSES).cuda()
self.net = nn.DataParallel(self.net, self.args.gpu_ids)

3、设置优化器

4、设置损失标准,如:

self.optimizer = torch.optim.Adam(self.net.parameters())
self.criterion = torch.nn.CrossEntropyLoss().cuda()

5、从dataloader中获取数据

6、优化器梯度设置0

7、将img传入net获得output

8、计算output和label的损失

9、损失反向传播

10、优化器执行下一步,如:

        self.net.train()
        for i,[img, label] in enumerate(self.train_img_loader):
            self.optimizer.zero_grad()
            label = label.long()
            if self.args.cuda:
                img, label = img.cuda(), label.cuda()
            output = self.net(img)
            loss = self.criterion(output, label)
            loss.backward()
            self.optimizer.step()

11、执行一段时间,保存net模型

这里有哪些需要注意的呢?

  1. 多GPU时,.cuda()写在model、criterion、img、label的后面。
  2. 可是使用一些输出的控件进行显示。

再说valid验证的部分(当然,这部分可有可无)

这里只说注意事项!!!

  1. 验证的时候我们的模型是固定参数的了,所以这里不能写net.train()了,要写net.eval()
  2. 验证的时候因为模型参数不用变化,所以没有优化器的设置,不需要损失的反向传播

6、测试test

这里又多加了一个test用来测试,为什么要说这个呢,这是因为这里有很多细节的东西需要说一下。

之前的操作分别对img、label、output做了哪些操作呢?

举个例子:

img的操作基本为:

输入灰度图(二维[W, H])–>Resize成[1, H, W](为什么要将其resize成3维呢,这是因为net的输入必须是4维的,在DataLoader中加上BatchSize变成了[B, 1, W, H],正好满足输入要求)–>标准化/归一化(0~1之间)–>在DataLoader中变为四维[B,1,W,H](其实到这里就完成了,但是如果想要再次输出的话)–>Resize成三维[B, W, H]–>*std–>+mean–>然后取batchsize中每一张图二维[W, H]即可。

label的操作基本为(如果采用CrossEntropy损失函数):

输入灰度图(二维[W, H])–>将灰度图encode成segmap(如果是像素二分类,则变为0-1矩阵,分别对应不同的分类)–>在DataLoader中变成了三维[B, W, H](这样就可以和output四维计算交叉熵损失了,交叉熵损失两个参数的维度分别是:[n,n-1], 其实到这里就完了,如果想要再次输出的话)–>取每一张图二维[W, H]Decode成0~255的矩阵即可。

output的操作基本为:

输出为4位矩阵[B, NUM_CLASSES, W, H]–>采用torch.argmax(output, dim=1).cpu().numpy()将output各个channel图片融合,并降维为[B, W, H]–>取每个batchsize的输出二维[W, H]decode成0~255的矩阵即可。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/171655.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • delphi数组排序_sql排序函数

    delphi数组排序_sql排序函数var  m_bSort:  boolean=false;  //控制正反排序的变量     //ListView排序的回调函数,默认的是快速排序法,也可以自己在这里做算法  function  CustomSortProc(Item1,  Item2:  TListItem;  ParamSort:  integer):  integer;  stdcall

  • 牛屎芯片 | 硬件之家「建议收藏」

    牛屎芯片 | 硬件之家「建议收藏」前言:牛屎芯片又叫邦定芯片或软封装芯片,一般应用于价格较为低廉的电子设备中。原文链接:http://www.allchiphome.com/post/cow-shit_chip一、牛屎芯片牛屎芯

  • plc梯形图编程入门编程_梯形图编程语言由什么组成

    plc梯形图编程入门编程_梯形图编程语言由什么组成梯形图(LAD)是PLC编程的最佳可视化语言,它看起来非常类似于继电器电路图,因此如果你对继电器控制和电子电路有所了解的话,那么学起来会非常容易!在这个教程中,我们将学习关于使用梯形图进行PLC编程的有关知识。现在,让我们开始吧!什么是梯形图梯形图是一种PLC编程语言,也被称为梯形逻辑(LadderLogic)。之所以称为梯形图,是因为这种程序由一条条水平线构成,看起来很像梯子。梯形…

    2022年10月19日
  • Sublime Text3 如何安装、删除及更新插件

    Sublime Text3 如何安装、删除及更新插件1、打开SublimeText3,按Ctrl+`(和qq输入法快捷切换冲突,可以修改qq的输入法切换热键)2、复制粘黏以下代码添加至命令行,然后回车(功能:安装插件的工具,有了它,以后安装其他插件更方便)importurllib.request,os;pf=’PackageControl.sublime-package’;ipp=sublime.inst…

  • Android MVP+RxJava+Retrofit (3) MVP+RxJava+Retrofit

    Android MVP+RxJava+Retrofit (3) MVP+RxJava+Retrofit

  • binlog日志记录什么内容_mysqlbinlog日志在哪

    binlog日志记录什么内容_mysqlbinlog日志在哪(一)binlog介绍binlog,即二进制日志,它记录了数据库上的所有改变,并以二进制的形式保存在磁盘中;它可以用来查看数据库的变更历史、数据库增量备份和恢复、Mysql的复制(主从数据库的复制)。(二)binlog格式binlog有三种格式:Statement、Row以及Mixed。–基于SQL语句的复制(statement-basedreplic

    2022年10月14日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号