【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()

【tensorflow】MTCNN网络基本函数bbox_ohem&landmark_ohem()tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来importtensorflowastfimportnumpyasnpa=tf.constant([1,2,3,4])b=tf.square(a)withtf.Session()assess:print(“b:%s”%sess.run(b))#b:[14916]…

大家好,又见面了,我是你们的朋友全栈君。

tf.gather:用一个一维的索引数组,将张量中对应索引的向量提取出来

import tensorflow as tf
import numpy as np
a = tf.constant([1,2,3,4])
b = tf.square(a)
with tf.Session() as sess:
    print("b:%s" % sess.run(b))
# b:[ 1  4  9 16]
import numpy as np
import tensorflow as tf
def bbox_ohem(bbox_pred,bbox_target,label):
    '''
    :param bbox_pred:
    :param bbox_target:
    :param label: class label
    :return: mean euclidean loss for all the pos and part examples
    '''
    zeros_index = tf.zeros_like(label, dtype=tf.float32)
    ones_index = tf.ones_like(label, dtype=tf.float32)
    #获取pos样本和part样本
    valid_inds = tf.where(tf.equal(tf.abs(label),1),ones_index,zeros_index)
    #(batch,)
    #计算平方和(按行)tf.square(bbox_pred-bbox_target): 求每个数的平方值
    square_error = tf.square(bbox_pred-bbox_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    with tf.Session() as sess:
        print("bbox_pred-bbox_target:%s"%(sess.run(bbox_pred-bbox_target)))
        print("square_error:%s" % (sess.run(square_error)))
    # 计算pos样本和part样本的数量
    num_valid = tf.reduce_sum(valid_inds)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    # 去掉neg样本和landmark样本的平方和
    square_error = square_error*valid_inds
    # 获取前K个样本的索引,K为pos和part样本的数量
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    # 将所有pos样本和part样本的平方和提取出来
    square_error = tf.gather(square_error, k_index)
    # 返回均值
    return tf.reduce_mean(square_error)

bbox_pred = tf.random_uniform([2,4],10,100,seed = 100)
bbox_target = tf.random_uniform([2,4],15,150,seed = 100)
with tf.Session() as sess:
    print("cls_prob:%s"%(sess.run(bbox_pred)))
label = np.array([1,0])
bbox_ohem(bbox_pred,bbox_target,label)

在这里插入图片描述

landmark_ohem:作用就是返回landmark的损失,用的是landmark样本。

def landmark_ohem(landmark_pred,landmark_target,label):
    '''

    :param landmark_pred:
    :param landmark_target:
    :param label:
    :return: mean euclidean loss
    '''
    #keep label =-2  then do landmark detection
    ones = tf.ones_like(label,dtype=tf.float32)
    zeros = tf.zeros_like(label,dtype=tf.float32)
    valid_inds = tf.where(tf.equal(label,-2),ones,zeros)
    square_error = tf.square(landmark_pred-landmark_target)
    square_error = tf.reduce_sum(square_error,axis=1)
    num_valid = tf.reduce_sum(valid_inds)
    #keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32)
    keep_num = tf.cast(num_valid, dtype=tf.int32)
    square_error = square_error*valid_inds
    _, k_index = tf.nn.top_k(square_error, k=keep_num)
    square_error = tf.gather(square_error, k_index)
    return tf.reduce_mean(square_error)
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/139359.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • Pytest(6)重复运行用例pytest-repeat

    Pytest(6)重复运行用例pytest-repeat前言平常在做功能测试的时候,经常会遇到某个模块不稳定,偶然会出现一些bug,对于这种问题我们会针对此用例反复执行多次,最终复现出问题来。自动化运行用例时候,也会出现偶然的bug,可以针对单个用例,

  • 【并发缺陷】data race数据竞争、atomicity violation原子违背、order violation顺序违背

    【并发缺陷】data race数据竞争、atomicity violation原子违背、order violation顺序违背三类均是跟共享变量的内存访问有关的缺陷。对于并发缺陷的分类目前国内许多是分死锁、数据竞争、原子违背、顺序违背。或者在并发缺陷中又细分concurrencyvulnerability:死锁和数据竞争。感觉各个作者有自己的分类方法????以下引用的中文解释来自<并发缺陷暴露、检测与规避研究综述>哈工大的苏小红老师实验室发表在2015年计算机学报上目前找到外文文献分为7类。其他四类…

    2022年10月29日
  • 通过nginx转发WebSocket

    通过nginx转发WebSocket通过nginx请求wensocket的时候需要修改配置文件,对于websocket请求需要特殊处理一下,需要在conf配置文件中添加一些配置:server{listen8080;server_nametest.com;add_header’Access-Control-Allow-Origin”*’always;add_header’Access-Control-Allow-Credentials”true’;add_header’A

    2022年10月18日
  • vue取消eslint规范_vue运行eslint报错

    vue取消eslint规范_vue运行eslint报错关闭eslint代码规范检查

  • socket常用函数_socket recv函数

    socket常用函数_socket recv函数摘要在linux下,使用socketpair函数能够创建一对套节字进行进程间通信(IPC)。函数原形:#include&lt;sys/types.h&gt;#include&lt;sys/socket.h&gt;intsocketpair(intdomain,inttype,intprotocol,intsv[2]);参数1(domain):表示协…

    2022年10月14日
  • beanutils工具类_beanutils.copyproperties忽略null

    beanutils工具类_beanutils.copyproperties忽略null什么是BeanUtils工具BeanUtils工具是一种方便我们对JavaBean进行操作的工具,是Apache组织下的产品。BeanUtils工具一般可以方便javaBean的哪些操作?1)bean

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号