tensorflow2.2_实现Resnet34_花的识别[通俗易懂]

tensorflow2.2_实现Resnet34_花的识别[通俗易懂]残差块    Resnet是由许多残差块组成的,而残差块可以解决网络越深,效果越差的问题。    残差块的结构如下图所示。其中:weightlayer表示卷积层,用于特征提取。F(x)F(x)F(x)表示经过两层卷积得到的结果。xxx表示恒等映射。F(x)+xF(x)+xF(x)+x表示经过两层卷积后与之前的卷积层进行结合。所以F(x)F(x)F(x)和xxx代表的是相同的信号。作用:将浅层网络的信号递给深层网络,使网络得到更好的结果。批量归一化(BatchNormaliz

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

残差块

    Resnet是由许多残差块组成的,而残差块可以解决网络越深,效果越差的问题
    残差块的结构如下图所示。
在这里插入图片描述
其中:

  1. weight layer表示卷积层,用于特征提取。
  2. F ( x ) F(x) F(x)表示经过两层卷积得到的结果。
  3. x x x表示恒等映射
  4. F ( x ) + x F(x)+x F(x)+x表示经过两层卷积后与之前的卷积层进行结合。

所以 F ( x ) F(x) F(x) x x x代表的是相同的信号。

  • 作用:将浅层网络的信号递给深层网络,使网络得到更好的结果。

批量归一化(Batch Normalization)

    我们暂时简称它为BN。
    BN可以对网络中的每一层的输入,输出特征进行标准化处理,将他们变成均值为0,方差为1的分布。
标准化的公式如下:
在这里插入图片描述
其中:

  • x n x_n xn表示第n个维度的数据
  • μ μ μ为该维度的平均值
  • σ σ σ表示该维度的方差
  • ϵ ϵ ϵ表示一个很小很小的值,防止分母为零

BN的主要作用:

  1. 加快模型的收敛速度。
  2. 增强正则化的作用。

Resnet34网络结构

如下图:
在这里插入图片描述
其中:

  • 7×7 conv 表示7×7大小的卷积核的窗口
  • 3×3 conv 表示3×3大小的卷积核的窗口
  • 64、128、256、512表示特征图的数量
  • /2 表示卷积核的步长,没写就默认为1
  • 虚线表示无法直接连接,因为生成的特征图数量是不一样的,也就是说shape是不一样的,一般是使用步长为2、大小为1的卷积核来对输入信号进行特征提取,使输入信号和输出信号的shape一致,再进行结合。

代码演示

1. 导入相关库

可新建一个train.py文件

from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam

2. 定义网络结构

# 结构快
def block(x, filters, strides=2, conv_short=True):
if conv_short:
short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
else:
short_cut = x
# 2层卷积
x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
x = BatchNormalization(epsilon=1.001e-5)(x)
x = Activation('relu')(x)
x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
x = BatchNormalization(epsilon=1.001e-5)(x)
x = Activation('relu')(x)
x = Add()([x, short_cut])
x = Activation('relu')(x)
return x
def Resnet34(inputs, classes):
x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=128, strides=2, conv_short=True)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=256, strides=2, conv_short=True)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=512, strides=2, conv_short=True)
x = block(x, filters=512, strides=1, conv_short=False)
x = block(x, filters=512, strides=1, conv_short=False)
x = GlobalAvgPool2D()(x)
x = Dense(classes, activation='softmax')(x)
return x

3. 定义超参数

数据集:
链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
提取码:bhjx
复制这段内容后打开百度网盘手机App,操作更方便哦
权重文件:
链接:https://pan.baidu.com/s/1JotFy2G5wdThj409K87ExA
提取码:4vi5
复制这段内容后打开百度网盘手机App,操作更方便哦

数据集格式:
test和train文件夹里面需要按类别存放,如下

- dataset
- data1_dog_cat
- test
- cat
- cat.10000.jpg
- cat.10001.jpg
- ...
- dog
- dog.10000.jpg
- dog.10001.jpg
- train
- cat
- cat.0.jpg
- cat.1.jpg
- ...
- dog
- dog.0.jpg
- dog.1.jpg
- ...
classes = 17 # 需要分类的类别
batch_size = 16 # 批次大小
epochs = 100 # 轮次
img_size = 224 # 图片大小
lr = 1e-3 # 学习率大小
datasets = './dataset/data_flower' # 数据集的路径
weight = './model_data/test_acc0.794-resnet18val_loss0.857-flower.h5' # 权重文件的路径
# ------------------------------- #
# 我们使用加载权重的方式进行训练,效果会更好

4. 定义数据处理的构造器

train_data = ImageDataGenerator(
rotation_range=20, 
width_shift_range=0.1, 
height_shift_range=0.1,
rescale=1/255.0,
shear_range=10,
zoom_range=0.1,
horizontal_flip=True,
brightness_range=(0.7, 1.3),
fill_mode='nearest'
)
test_data = ImageDataGenerator(
rescale=1/255
)
train_generator = train_data.flow_from_directory(
f'{ 
datasets}/train',
target_size=(img_size, img_size),
batch_size=batch_size
)
test_generator = test_data.flow_from_directory(
f'{ 
datasets}/test',
target_size=(img_size, img_size),
batch_size=batch_size
)

5. 定义学习率回调函数

def adjust_lr(epoch, lr=lr):
print("Seting to %s" % (lr))
if epoch < 10:
return lr
else:
return lr * 0.93

6. 主函数

需要新建一个logs文件夹,保存权重文件。

if __name__ == '__main__':
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
inputs = Input(shape=(img_size,img_size,3))
model = Model(inputs=inputs, outputs=Resnet34(inputs=inputs, classes=classes))
callbackss = [
EarlyStopping(monitor='val_loss', patience=10, verbose=1),
ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss',
save_weights_only=True, save_best_only=False, period=1),
LearningRateScheduler(adjust_lr)
]
if weight:
print('---------->loding weight--------->')
model.load_weights(weight, by_name=True, skip_mismatch=True)
model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(
x                      = train_generator,
validation_data        = test_generator,
workers                = 1,
epochs                 = epochs,
callbacks              = callbackss
)
else:
print('---------->epoch0 starting--------->')
model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(
x                    = train_generator,
validation_data      = test_generator,
epochs               = epochs,
workers              = 1,
callbacks            = callbackss
)

7. 预测图片

可新建一个predict.py文件
导入库


from PIL import Image
from tensorflow.keras.layers import Input
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, GlobalAvgPool2D, Input, BatchNormalization, Activation, Add

定义归一化函数

def preprocess_input(x):
x /= 255
return x

定义转RGB函数

def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image 
else:
image = image.convert('RGB')
return image 

定义参数
注意:weight需要指定训练好的权重文件

datasets = './dataset/data_flower/test'
names = os.listdir(datasets)
weight = './model_data/test_acc0.860-val_loss0.599-resnet34-flower.h5'
net = Resnet34
classes = 17
img_size = 224

定义网络模型

# 结构快
def block(x, filters, strides=2, conv_short=True):
if conv_short:
short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
else:
short_cut = x
# 2层卷积
x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
x = BatchNormalization(epsilon=1.001e-5)(x)
x = Activation('relu')(x)
x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
x = BatchNormalization(epsilon=1.001e-5)(x)
x = Activation('relu')(x)
x = Add()([x, short_cut])
x = Activation('relu')(x)
return x
def Resnet34(inputs, classes):
x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=64, strides=1, conv_short=False)
x = block(x, filters=128, strides=2, conv_short=True)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=128, strides=1, conv_short=False)
x = block(x, filters=256, strides=2, conv_short=True)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=256, strides=1, conv_short=False)
x = block(x, filters=512, strides=2, conv_short=True)
x = block(x, filters=512, strides=1, conv_short=False)
x = block(x, filters=512, strides=1, conv_short=False)
x = GlobalAvgPool2D()(x)
x = Dense(classes, activation='softmax')(x)
return x
inputs = Input(shape=(img_size,img_size,3))
model = Model(inputs=inputs, outputs=Resnet34(inputs=inputs, classes=classes))
# -------------------------------------------------#
# 载入模型
# -------------------------------------------------#
model.load_weights(weight)
while True:
img_path = input('input img_path:')
try:
img = Image.open(img_path)
img = cvtColor(img)
img = img.resize((224, 224))
image_data = np.expand_dims(preprocess_input(np.array(img, np.float32)), 0)
except:
print('The path is error!')
continue
else:
plt.imshow(img)
plt.axis('off')
p =model.predict(image_data)[0]
pred_name = names[np.argmax(p)]
plt.title('%s:%.3f'%(pred_name, np.max(p)))
plt.show()

效果如下:
在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/188086.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • pycharm选中一行代码快捷键_python代码自动对齐

    pycharm选中一行代码快捷键_python代码自动对齐在写代码的时候,经常为了对齐代码而烦恼,强大的pycharm为我们提供了一个代码自动对齐功能,而且可以使用快捷键完成。快捷键组合是:Ctrl+Alt+L将光标置于需要调整的代码行,或者选择一个区域,按下快捷键,代码就可以自动对齐啦!…

  • 图解springmvc 执行流程

    图解springmvc 执行流程核心对象DispatcherServlet核心控制器负责请求,响应,数据的分发。HandlerMapping处理器映射器,负责到controller中,找到对应的方法,返回给核心控制器。HandleAdapter处理适配器,将handle找到的方法执行,执行结果,即ModelAndView数据和视图返回给核心控制器。HttpMessageConvertor消息转换器,数据类型的转换,如日期…ViewResolver视图解析器,核心控制器调度视图解析器,视图解析器,返回视图。核心控制

  • WebApp开发-Google官方教程

    WebApp开发-Google官方教程概览你可以使用viewport的元数据、CSS和Javascript来为不同分辨率的屏幕设置合适的页面本文档中的技术适用于Android 2.0及以上设备,针对默认的Android Browser中及在WebView中呈现的页面如果你在为Android开发Web应用或者在为移动设备重新设计一个Web应用,你需要仔细考虑在不同设备上你的页面看起来是怎样的。因为Android设备有不同款型

  • 恋空 By whaosoft「建议收藏」

    恋空 By whaosoft「建议收藏」/序曲 如果那天,我没有遇见你。我想,我就不会感到如此痛苦、如此悲伤、如此难过、如此令人悲从中来了。但是,如果我没有遇见你。我也不会知道那么欢愉、那么温柔、那么相爱、那么温暖、那么幸福的心情了……噙着泪水的我,今天,依旧仰望着天空。 仰望着天空。I.虚幻的开始1 『哇~!!肚子超饿的啦~』期待已久的午休时间终于到了。美嘉一如往常地打开桌上的便当。来上学真的是麻烦事一大堆

  • js自动生成二维码_jquery 生成二维码无法识别

    js自动生成二维码_jquery 生成二维码无法识别生成二维码并保存为图片,点击下载此二维码简单实现的效果,如有更好的请指教利用jquery加jquery.qrcode //外部的js <scriptsrc=”./jquery/2.1.4/jquery.min.js”></script><scriptsrc=”./jquery/jquery.qrcode.min.js”></scri…

    2022年10月18日
  • linux rpm卸载包及其依赖,Linux下如何用rpm卸载软件 rpm依赖包强制卸载

    linux rpm卸载包及其依赖,Linux下如何用rpm卸载软件 rpm依赖包强制卸载以Mysql为例。#查看安装的Mysql版本sjgx2:/usr/local/mysql/bin#rpm-qa|grep-imysqlMySQL-client-5.1.17-0.glibc23MySQL-server-5.1.17-0.glibc23#卸载sjgx2:/usr/local/mysql/bin#rpm-eMySQL-client-5.1.17-0.glibc23s…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号