Implementing softmax classification from scratch and with mxnet



1. softmax from scratch

from mxnet.gluon import data as gdata
from sklearn import datasets
from mxnet import nd,autograd
# Load the dataset
digits = datasets.load_digits()
features,labels = nd.array(digits['data']),nd.array(digits['target'])
print(features.shape,labels.shape)
labels_onehot = nd.one_hot(labels,10)
print(labels_onehot.shape)
(1797, 64) (1797,)
(1797, 10)
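
The shapes printed above show `nd.one_hot` turning 1797 integer labels into a 1797 x 10 matrix. As a rough illustration of what that encoding contains, here is a minimal sketch in plain Python (so it runs without mxnet). The `one_hot` helper below is hypothetical and only mimics the idea; it is not the mxnet implementation.

```python
def one_hot(labels, num_classes):
    """Return one row per label, with 1.0 at the label's index and 0.0 elsewhere."""
    rows = []
    for label in labels:
        row = [0.0] * num_classes
        row[int(label)] = 1.0
        rows.append(row)
    return rows

# Three made-up labels encoded over 10 classes.
encoded = one_hot([0, 3, 9], 10)
print(len(encoded), len(encoded[0]))
print(encoded[1])
```

Each row sums to 1, which is what lets the loss below be written as a sum of `-y * log(y_pred)` over all classes.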
class softmaxClassifier:
    def __init__(self,inputs,outputs):
        self.inputs = inputs
        self.outputs = outputs
        
        self.weight = nd.random.normal(scale=0.01,shape=(inputs,outputs))
        self.bias = nd.zeros(shape=(1,outputs))
        self.weight.attach_grad()
        self.bias.attach_grad()
        
    def forward(self,x):
        output = nd.dot(x,self.weight) + self.bias
        return self._softmax(output)
        
    def _softmax(self,x):
        step1 = x.exp()
        step2 = step1.sum(axis=1,keepdims=True)
        return step1 / step2
    
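    def _bgd(self,params,learning_rate,batch_size):
        '''
        Batch gradient descent (applied to each mini-batch)
        '''
        for param in params:       # use the gradients computed by mxnet autograd
            # loss() already averages over the batch, so the step is divided by batch_size a second time
            param[:] = param - param.grad * learning_rate / batch_size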
            
    def loss(self,y_pred,y):
        return nd.sum((-y * y_pred.log())) / len(y)
            
    def dataIter(self,x,y,batch_size):
        dataset = gdata.ArrayDataset(x,y)
        return gdata.DataLoader(dataset,batch_size,shuffle=True)
    
    def fit(self,x,y,learning_rate,epoches,batch_size):
        for epoch in range(epoches):
            for x_batch,y_batch in self.dataIter(x,y,batch_size):
                with autograd.record():
                    y_pred = self.forward(x_batch)
                    l = self.loss(y_pred,y_batch)
                l.backward()
                self._bgd([self.weight,self.bias],learning_rate,batch_size)
            # Runs when epoch is 0, 50, ..., 450; the printed label is epoch+50
            if epoch % 50 == 0:
                y_all_pred = self.forward(x)
                print('epoch:{},loss:{},accuracy:{}'.format(epoch+50,self.loss(y_all_pred,y),self.accuracyScore(y_all_pred,y)))
            
    def predict(self,x):
        y_pred = self.forward(x)
        return y_pred.argmax(axis=1)   # class with the highest probability for each sample (row)
    
    def accuracyScore(self,y_pred,y):
        acc_sum = (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().asscalar()
        return acc_sum / len(y)
sfm_clf = softmaxClassifier(64,10)
sfm_clf.fit(features,labels_onehot,learning_rate=0.1,epoches=500,batch_size=200)
epoch:50,loss:
[1.9941667]
<NDArray 1 @cpu(0)>,accuracy:0.3550361713967724
epoch:100,loss:
[0.37214527]
<NDArray 1 @cpu(0)>,accuracy:0.9393433500278241
epoch:150,loss:
[0.25443634]
<NDArray 1 @cpu(0)>,accuracy:0.9549248747913188
epoch:200,loss:
[0.20699367]
<NDArray 1 @cpu(0)>,accuracy:0.9588202559821926
epoch:250,loss:
[0.1799827]
<NDArray 1 @cpu(0)>,accuracy:0.9660545353366722
epoch:300,loss:
[0.1619963]
<NDArray 1 @cpu(0)>,accuracy:0.9677239844184753
epoch:350,loss:
[0.14888664]
<NDArray 1 @cpu(0)>,accuracy:0.9716193656093489
epoch:400,loss:
[0.13875261]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:450,loss:
[0.13058177]
<NDArray 1 @cpu(0)>,accuracy:0.9760712298274903
epoch:500,loss:
[0.12379646]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933
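
A side note on `_softmax`: it exponentiates the raw scores directly. That works for this dataset, but with large scores the exponential can overflow. A common remedy is to subtract the row maximum before exponentiating, which leaves the result mathematically unchanged. The sketch below shows the idea in plain Python for a single row; `softmax_row` is a hypothetical helper, not part of the class above.

```python
import math

def softmax_row(scores):
    """Softmax of one row of scores, shifted by the row maximum before exponentiating."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# math.exp(1000.0) on its own would raise OverflowError; the shifted version does not.
probs = softmax_row([1000.0, 1001.0, 1002.0])
print(probs)
print(sum(probs))
```

In the class above, the same idea would presumably amount to subtracting `x.max(axis=1,keepdims=True)` from `x` before calling `exp()`; that change has not been run against the results shown here.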
print('Predicted:',sfm_clf.predict(features[:10]))
print('Actual:',labels[:10])
Predicted: 
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>
Actual: 
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>
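
One thing to watch when turning probabilities into predictions: for a (samples, classes) matrix, the predicted class of each sample is the argmax along axis 1. The argmax along axis 0 instead answers a different question (which sample scores highest for each class). The first ten digits happen to be 0 through 9 in order, so on that slice the two can coincide and hide the difference. A minimal plain-Python sketch with made-up probabilities; both helpers are hypothetical stand-ins for `argmax(axis=1)` and `argmax(axis=0)`.

```python
def argmax_rows(matrix):
    """Index of the largest value in each row (axis 1): one class per sample."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]

def argmax_cols(matrix):
    """Index of the largest value in each column (axis 0): one sample per class."""
    cols = list(zip(*matrix))
    return [max(range(len(col)), key=col.__getitem__) for col in cols]

# Hypothetical probabilities for 2 samples over 3 classes.
probs = [
    [0.1, 0.2, 0.7],   # sample 0 -> class 2
    [0.6, 0.3, 0.1],   # sample 1 -> class 0
]
print(argmax_rows(probs))  # [2, 0]
print(argmax_cols(probs))  # [1, 1, 0]
```

The row-wise result has one entry per sample, which is the shape a prediction should have; the column-wise result always has one entry per class, regardless of how many samples are passed in.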

2. softmax classification with mxnet

from mxnet import gluon,nd,autograd,init
from mxnet.gluon import nn,trainer,loss as gloss,data as gdata
# Define the model
net = nn.Sequential()
net.add(nn.Dense(10))

# Initialize the model
net.initialize(init=init.Normal(sigma=0.01))

# Loss function
loss = gloss.SoftmaxCrossEntropyLoss(sparse_label=False)

# Optimization algorithm
optimizer = trainer.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1})

# Training
epoches = 500
batch_size = 200

dataset = gdata.ArrayDataset(features, labels_onehot)
data_iter = gdata.DataLoader(dataset,batch_size,shuffle=True)
for epoch in range(epoches):
    for x_batch,y_batch in data_iter:
        with autograd.record():
            l = loss(net.forward(x_batch), y_batch).sum() / batch_size
        l.backward()
        optimizer.step(batch_size)
    # Runs when epoch is 0, 50, ..., 450; the printed label is epoch+50
    if epoch % 50 == 0:
        y_all_pred = net.forward(features)
        acc_sum = (y_all_pred.argmax(axis=1) == labels_onehot.argmax(axis=1)).sum().asscalar()
        print('epoch:{},loss:{},accuracy:{}'.format(epoch+50,loss(y_all_pred,labels_onehot).sum() / len(labels_onehot),acc_sum/len(y_all_pred)))
epoch:50,loss:
[2.1232333]
<NDArray 1 @cpu(0)>,accuracy:0.24652198107957707
epoch:100,loss:
[0.37193483]
<NDArray 1 @cpu(0)>,accuracy:0.9410127991096272
epoch:150,loss:
[0.25408813]
<NDArray 1 @cpu(0)>,accuracy:0.9543683917640512
epoch:200,loss:
[0.20680156]
<NDArray 1 @cpu(0)>,accuracy:0.9627156371730662
epoch:250,loss:
[0.1799252]
<NDArray 1 @cpu(0)>,accuracy:0.9666110183639399
epoch:300,loss:
[0.16203885]
<NDArray 1 @cpu(0)>,accuracy:0.9699499165275459
epoch:350,loss:
[0.14899409]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:400,loss:
[0.13890252]
<NDArray 1 @cpu(0)>,accuracy:0.9749582637729549
epoch:450,loss:
[0.13076076]
<NDArray 1 @cpu(0)>,accuracy:0.9755147468002225
epoch:500,loss:
[0.1239901]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933
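
Both training loops appear to normalize by the batch size twice: the loss is averaged over the batch, and the update then divides the gradient by `batch_size` again (explicitly in `_bgd`, and through `optimizer.step(batch_size)`, which rescales gradients by 1/batch_size). If that reading is right, the effective step size is `learning_rate / batch_size` rather than `learning_rate`. This is consistent with the two versions producing nearly identical numbers. The arithmetic can be sketched with a single scalar parameter; the gradient values below are made up for illustration.

```python
learning_rate = 0.1
batch_size = 200

# Hypothetical per-sample gradients of the loss with respect to one scalar parameter.
per_sample_grads = [0.5] * batch_size

# Gradient of the batch-averaged loss: already divided by the batch size once.
mean_grad = sum(per_sample_grads) / batch_size

# Update as written in the training loops: divide by batch_size a second time.
step_as_written = learning_rate * mean_grad / batch_size

# Update if the averaged gradient were used directly.
step_single_normalization = learning_rate * mean_grad

print(step_as_written, step_single_normalization)
print(learning_rate / batch_size)  # effective learning rate under double normalization
```

The results printed above were obtained with the code as written. Summing the loss instead of averaging it, or stepping with a batch size of 1, would apply `learning_rate` directly and would likely need far fewer than 500 epochs, but that variant was not run here.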

Published by 全栈程序员-用户IM. When reposting, please credit the source: https://javaforall.cn/120005.html Original link: https://javaforall.cn
