Python实现softmax函数「建议收藏」

Python实现softmax函数:PS:为了避免求exp(x)出现溢出的情况,一般需要减去最大值。#-*-coding:utf-8-*-importtensorflowastfimportnumpyasnpdefsoftmax(x,axis=1):#计算每行的最大值row_max=x.max(axis=axis)…

大家好,又见面了,我是你们的朋友全栈君。

 Python实现softmax函数 :

Python实现softmax函数「建议收藏」

PS:为了避免求exp(x)出现溢出的情况,一般需要减去最大值。

# -*-coding: utf-8 -*-

import tensorflow as tf
import numpy as np

def softmax(x, axis=1):
    # 计算每行的最大值
    row_max = x.max(axis=axis)

    # 每行元素都需要减去对应的最大值,否则求exp(x)会溢出,导致inf情况
    row_max=row_max.reshape(-1, 1)
    x = x - row_max

    # 计算e的指数次幂
    x_exp = np.exp(x)
    x_sum = np.sum(x_exp, axis=axis, keepdims=True)
    s = x_exp / x_sum
    return s


A = [[1, 1, 5, 3],
     [0.2, 0.2, 0.5, 0.1]]
A= np.array(A)
axis = 1  # 默认计算最后一维

# [1]使用自定义softmax
s1 = softmax(A, axis=axis)
print("s1:{}".format(s1))


#[2]使用TF的softmax
with tf.Session() as sess:
    tf_s2=tf.nn.softmax(A, axis=axis)
    s2=sess.run(tf_s2)
    print("s2:{}".format(s2))

C++实现Softmax函数

template<typename _Tp>
int softmax(const _Tp* src, _Tp* dst, int length)
{
//    double max = 0.0;
//    double sum = 0.0;
//
//    for (int i = 0; i<k; i++) if (max < x[i]) max = x[i];
//    for (int i = 0; i<k; i++) {
//        x[i] = exp(x[i] - max);
//        sum += x[i];
//    }
//    for (int i = 0; i<k; i++) x[i] /= sum;
    //为了避免溢出,需要减去最大值
    const _Tp max_value = *std::max_element(src, src + length);
    _Tp denominator{ 0 };

    for (int i = 0; i < length; ++i) {
        dst[i] = std::exp(src[i] - max_value);
        denominator += dst[i];
    }

    for (int i = 0; i < length; ++i) {
        dst[i] /= denominator;
    }
    return 0;
}
std::vector<float> output_vector;
std::vector<float> preds;
softmax(output_vector.data(), preds.data(),output_vector.size());

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/129556.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • ldc1614 c语言编程,LDC1614读回来的数据为固定值不变[通俗易懂]

    ldc1614 c语言编程,LDC1614读回来的数据为固定值不变[通俗易懂]OtherPartsDiscussedinThread:LDC1614,LDC1314,LDC1614EVM求教一下各位前辈,硬件是用的LDC1614的评估板,ch0和ch1上接了两个线圈,并联的电容为100pF,器件ID和装配ID读出来为0x3055和0x5449,和手册上的一致,而且我写寄存器再读出来数据都是对的,排除了软件驱动上的问题,现在可能是配置上有哪里不对,或者芯片有问题(…

  • 解决 bcm43问题

    解决 bcm43问题

  • git 拉新分支_git基于远程分支新建本地分支

    git 拉新分支_git基于远程分支新建本地分支原文地址:http://www.cnblogs.com/lingear/p/6062093.html开发过程中经常用到从master分支copy一个开发分支,下面我们就用命令行完成这个操作:1.切换到被copy的分支(master),并且从远端拉取最新版本$gitcheckoutmaster$gitpull2.从当前分支拉copy开发分

  • 计算机基础三: 二进制减法实现[通俗易懂]

    计算机基础三: 二进制减法实现[通俗易懂]在上一章中了解了如何实现二进制加法,加法是始终从两个加数的最右列向左列进位计算的,而在减法中没有进位,只有借位.253-176=77上面的式子我们不难算出来,但习惯性的思维让我们用借位的方式求值.在不借位的情况下如何实现计算?借位是很麻烦的事情,虽然我们能够实现它,但这意味着额外的开销.我们将用一个小技巧,让我们避开借位从而实现减法.为了避免借位,我们先从百位最大值999中减去减数,而非从原来的被减数中减去减数.999-176=823这个方法称为对9求补

  • spring boot data jdbc_java连接数据库详细步骤

    spring boot data jdbc_java连接数据库详细步骤Spring Boot入门(五):使用JDBC访问MySql数据库

  • matlab矩阵存为txt_matlab数据批量处理

    matlab矩阵存为txt_matlab数据批量处理fileID=fopen(‘Data.txt’);A=textscan(fileID,’%f%*f%*f%*f%f%f%f’);fclose(fileID);Matrix=cell2mat(A);textscan中,%*f表示不读取该列数据。

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号