Python实现softmax函数「建议收藏」

Python实现softmax函数:PS:为了避免求exp(x)出现溢出的情况,一般需要减去最大值。#-*-coding:utf-8-*-importtensorflowastfimportnumpyasnpdefsoftmax(x,axis=1):#计算每行的最大值row_max=x.max(axis=axis)…

大家好,又见面了,我是你们的朋友全栈君。

 Python实现softmax函数 :

Python实现softmax函数「建议收藏」

PS:为了避免求exp(x)出现溢出的情况,一般需要减去最大值。

# -*-coding: utf-8 -*-

import tensorflow as tf
import numpy as np

def softmax(x, axis=1):
    # 计算每行的最大值
    row_max = x.max(axis=axis)

    # 每行元素都需要减去对应的最大值,否则求exp(x)会溢出,导致inf情况
    row_max=row_max.reshape(-1, 1)
    x = x - row_max

    # 计算e的指数次幂
    x_exp = np.exp(x)
    x_sum = np.sum(x_exp, axis=axis, keepdims=True)
    s = x_exp / x_sum
    return s


A = [[1, 1, 5, 3],
     [0.2, 0.2, 0.5, 0.1]]
A= np.array(A)
axis = 1  # 默认计算最后一维

# [1]使用自定义softmax
s1 = softmax(A, axis=axis)
print("s1:{}".format(s1))


#[2]使用TF的softmax
with tf.Session() as sess:
    tf_s2=tf.nn.softmax(A, axis=axis)
    s2=sess.run(tf_s2)
    print("s2:{}".format(s2))

C++实现Softmax函数

template<typename _Tp>
int softmax(const _Tp* src, _Tp* dst, int length)
{
//    double max = 0.0;
//    double sum = 0.0;
//
//    for (int i = 0; i<k; i++) if (max < x[i]) max = x[i];
//    for (int i = 0; i<k; i++) {
//        x[i] = exp(x[i] - max);
//        sum += x[i];
//    }
//    for (int i = 0; i<k; i++) x[i] /= sum;
    //为了避免溢出,需要减去最大值
    const _Tp max_value = *std::max_element(src, src + length);
    _Tp denominator{ 0 };

    for (int i = 0; i < length; ++i) {
        dst[i] = std::exp(src[i] - max_value);
        denominator += dst[i];
    }

    for (int i = 0; i < length; ++i) {
        dst[i] /= denominator;
    }
    return 0;
}
std::vector<float> output_vector;
std::vector<float> preds;
softmax(output_vector.data(), preds.data(),output_vector.size());

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/129556.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • c语言 switch错误用法,C语言switch语句的详细用法[通俗易懂]

    c语言 switch错误用法,C语言switch语句的详细用法[通俗易懂]C语言还为多分支选择提供了另一个switch语句,其一般形式为:开关(表达式){案例常量表达式1:语句1;案例常量表达式2:语句2;…条件常量表达式n:语句n;默认值:语句n+1;}语义是:计算表达式的值.将其与后续常量表达式的值一一比较.当表达式的值等于常量表达式的值时,将执行后续语句,然后不进行判断,并且在个案之后的所有后续语句将继续.如果在所有情况下表达式的值都与常量…

  • 这一次,终于系统的学习了 JVM 内存结构

    这一次,终于系统的学习了 JVM 内存结构最近在看《JAVA并发编程实践》这本书,里面涉及到了Java内存模型,通过Java内存模型顺理成章的来到的JVM内存结构,关于JVM内存结构的认知还停留在上大学那会的课堂上,一直没有系统的学习这一块的知识,所以这一次我把《深入理解Java虚拟机JVM高级特性与最佳实践》、《Java虚拟机规范JavaSE8版》这两本书中关于JVM内存结构的部分都看了一遍,算是…

  • golang []byte和string相互转换

    golang []byte和string相互转换测试例子:packagemainimport(“fmt”)funcmain(){str2:=”hello”data2:=[]byte(str2)fmt.Println(data2)str2=string(data2[:])fmt.Println(str2)}

  • 未连接到互联网代理服务器出现问题或地址有误(代理服务器的ip地址是多少)

    今天遇到一个问题:【校园网】打开电脑其实网络无法连接到安全代理服务器,本地IP地址非法,无法打开本地连接属性既看不到TCP/IP地址框。解决方法:1.重新启动计算机后发现网卡被禁用,重新启用就好了。

  • gif录屏与gif图片合成工具「建议收藏」

    gif录屏与gif图片合成工具「建议收藏」现在好多gif图片合成是收费的,而且可能还不太好用,这里分析的gif合成软件是个比较老的软件,但是用着还是挺好用的。还有一个录屏软件,录制保存为gif文件。百度网盘分享,无需积分:链接:https://pan.baidu.com/s/1HukTW6yJvqoUiqbzXuY5bQ提取码:pvc4欢迎关注微信公众号,分享更多实用工具:…

  • androidstudio 优化gradle编译效率[通俗易懂]

    androidstudio 优化gradle编译效率

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号