tensorflow模型查看参数(pytorch conv2d函数详解)

tf.nn.conv2d()参数解析

大家好,又见面了,我是你们的朋友全栈君。

定义:
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
功能:将两个4维的向量input(样本数据矩阵)和filter(卷积核)做卷积运算,输出卷积后的矩阵
input的形状:[batch, in_height ,in_width, in_channels]
batch: 样本的数量
in_height :每个样本的行数
in_width: 每个样本的列数
in_channels:每个样本的通道数,如果是RGB图像就是3
filter的形状:[filter_height, filter_width, in_channels, out_channels]
filter_height:卷积核的高
filter_width:卷积核的宽
in_channels:输入的通道数
out_channels:输出的通道数
比如在tensorflow的cifar10.py文件中有句:
这里写图片描述
卷积核大小为 5*5,输入通道数是3,输出通道数是64,即这一层输出64个特征
在看cifar10.py里第二层卷积核的定义:
这里写图片描述
大小依然是5*5,出入就是64个通道即上一层的输出,输出依然是64个特征
strides:[1,stride_h,stride_w,1]步长,即卷积核每次移动的步长
padding:填充模式取值,只能为”SAME”或”VALID”
卷积或池化后的节点数计算公式:
output_w = int((input_w + 2*padding – filter_w)/strid_w) + 1
举例说明:
假设这里使用的图像每副只有一行像素一通道,共3副图像

>>> a = np.array([[1,1,1],[2,2,2],[3,3,3]])
>>> b=tf.reshape(a,[a.shape[0],1,a.shape[1],1])
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(b)
array([[[[1], [1], [1]]], [[[2], [2], [2]]], [[[3], [3], [3]]]])

然后设有2个1*2的卷积核

>>> k=tf.constant([[[[ 1.0, 1.0]],[[2.0, 2.0]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3., 3.], [ 3., 3.], [ 1., 1.]]], [[[ 6., 6.], [ 6., 6.], [ 2., 2.]]], [[[ 9., 9.], [ 9., 9.], [ 3., 3.]]]], dtype=float32)
>>> sess.run(b)
array([[[[ 1.], [ 1.], [ 1.]]], [[[ 2.], [ 2.], [ 2.]]], [[[ 3.], [ 3.], [ 3.]]]], dtype=float32)
>>> sess.run(k)
array([[[[ 1., 1.]], [[ 2., 2.]]]], dtype=float32)

这里写图片描述
最后的0是函数自动填充的,所以最后就得到了一个2通道的卷积结果
将k改成[[ 1.0, 0.5],[2, 1]]然后再次运行:

>>> k=tf.constant([[[[ 1.0, 0.5]],[[2, 1]]]], dtype=tf.float32)
>>> mycov=tf.nn.conv2d(b, k, [1, 1, 1, 1], padding='SAME')
>>> init = tf.initialize_all_variables()
>>> sess.run(init)
>>> sess.run(mycov)
array([[[[ 3. , 1.5], [ 3. , 1.5], [ 1. , 0.5]]], [[[ 6. , 3. ], [ 6. , 3. ], [ 2. , 1. ]]], [[[ 9. , 4.5], [ 9. , 4.5], [ 3. , 1.5]]]], dtype=float32)

卷积核一般用tf.get_variable()初始化,这里为了演示直接指定为常量

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/129729.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • ICMP数据包分析_Wireshark数据包分析实战

    ICMP数据包分析_Wireshark数据包分析实战一.实验目的1.学习和掌握ICMP协议的基本作用和报文格式2.理解ICMP协议与IP协议的封装关系3.学习和掌握ICMP协议的应用和报文格式4.理解tracertoute工作过程二.实验拓扑三.实验工具GNS3和Wireshark抓包分析软件四.ICMP协议的封装格式(1)Type类型值,标识ICMP分组类型(2)Code代码值,标识ICMP分组类型的某一种具体分组(3)Checksum校验和,用于检验数据包是否完整或是否被修改(4)Identifier标识符,标识本进程

    2022年10月21日
  • 搭建一个属于自己的语音对话机器人

    搭建一个属于自己的语音对话机器人

  • linux下svn配置http访问「建议收藏」

    linux下svn配置http访问「建议收藏」CentOS服务器部署svn+apachehttp+sslhttps访问,本文详细介绍了svn配置apachehttp访问安装及配置过程。

  • 9.7 StringTokenizer类

    9.7 StringTokenizer类StringTokenizer类:解析字符串单词和split方法不同的是,StringTokenizer对象不使用正则表达式做分隔标记有时候要分析字符串并将字符串分解成可独立使用的单词,这些单词称为语言符号。对于字符串“Iamstudent”,如果把空格作为该字符串的标记,那么该字符串有三个单词(语言符号)。对于字符串“I,am,student”,如果把逗号作为该字符串的标…

  • 动态库依赖关系_查看运行的动态库

    动态库依赖关系_查看运行的动态库1前言这两天在编写一个插件系统Demo的时候,发现了个很奇怪的问题:插件加载器中已经链接了ld库,但是应用程序在链接插件加载器的时候,却还需要显式的来链接ld库。否则就会报:DSOmissingfromcommandline。这个报错翻译过来就是没有在命令行中指定该动态库。这个报错就很搞事了,你说你明明知道需要哪个库,为什么不直接帮我链接呢,非得我显示的在命令行中指定呢?2现象描…

  • springMVC 配置CharacterEncodingFilter之后不起效果

    springMVC 配置CharacterEncodingFilter之后不起效果最近开始自学springMVC框架,遇到中文乱码这一经典问题,记录下解决过程,以便后续忘记web.xml里过滤器配置如下:<?xmlversion=”1.0″encoding=”UTF-8″?><web-appxmlns:xsi=”http://www.w3.org/2001/XMLSchema-instance”xmlns=”http://j…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号