随机梯度下降法概述与实例分析_梯度下降法推导

随机梯度下降法概述与实例分析_梯度下降法推导机器学习算法中回归算法有很多,例如神经网络回归算法、蚁群回归算法,支持向量机回归算法等,其中也包括本篇文章要讲述的梯度下降算法,本篇文章将主要讲解其基本原理以及基于SparkMLlib进行实例示范,不足之处请多多指教。梯度下降算法包含多种不同的算法,有批量梯度算法,随机梯度算法,折中梯度算法等等。对于随机梯度下降算法而言,它通过不停的判断和选择当前目标下最优的路径,从而能够在最短路径…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

机器学习算法中回归算法有很多,例如神经网络回归算法、蚁群回归算法,支持向量机回归算法等,其中也包括本篇文章要讲述的梯度下降算法,本篇文章将主要讲解其基本原理以及基于Spark MLlib进行实例示范,不足之处请多多指教。

梯度下降算法包含多种不同的算法,有批量梯度算法,随机梯度算法,折中梯度算法等等。对于随机梯度下降算法而言,它通过不停的判断和选择当前目标下最优的路径,从而能够在最短路径下达到最优的结果。我们可以在一个人下山坡为例,想要更快的到达山低,最简单的办法就是在当前位置沿着最陡峭的方向下山,到另一个位置后接着上面的方式依旧寻找最陡峭的方向走,这样每走一步就停下来观察最下路线的方法就是随机梯度下降算法的本质。
这里写图片描述

随机梯度下降算法理论基础

在线性回归中,我们给出回归方程,如下所示:
这里写图片描述
我们知道,对于最小二乘法要想求得最优变量就要使得计算值与实际值的偏差的平方最小。而随机梯度下降算法对于系数需要通过不断的求偏导求解出当前位置下最优化的数据,那么梯度方向公式推导如下公式,公式中的θ会向着梯度下降最快的方向减少,从而推断出θ的最优解。

这里写图片描述

因此随机梯度下降法的公式归结为通过迭代计算特征值从而求出最合适的值。θ的求解公式如下。
这里写图片描述

α是下降系数,即步长,学习率,通俗的说就是计算每次下降的幅度的大小,系数越大每次计算的差值越大,系数越小则差值越小,但是迭代计算的时间也会相对延长。θ的初值可以随机赋值,比如下面的例子中初值赋值为0。

Spark MLlib随机梯度下降算法实例

下面使用Spark MLlib来迭代计算回归方程y=2x的θ最优解,代码如下:

package cn.just.shinelon.MLlib.Algorithm

import java.util

import scala.collection.immutable.HashMap

/**
  * 随机梯度下降算法实战
  * 随机梯度下降算法:最短路径下达到最优结果
  * 数学表达公式如下:
  * f(θ)=θ0x0+θ1x1+θ2x2+...+θnxn
  * 对于系数要通过不停地求解出当前位置下最优化的数据,即不停对系数θ求偏导数
  * 则θ求解的公式如下:
  * θ=θ-α(f(θ)-yi)xi
  * 公式中α是下降系数,即每次下降的幅度大小,系数越大则差值越小,系数越小则差值越小,但是计算时间也相对延长
  */
object SGD {
  var data=HashMap[Int,Int]()         //创建数据集
  def getdata():HashMap[Int,Int]={
    for(i <- 1 to 50){                //创建50个数据集
      data += (i->(2*i))              //写入公式y=2x
    }
    data                              //返回数据集
  }

  var θ:Double=0                        //第一步 假设θ为0
  var α:Double=0.1                      //设置步进系数

  def sgd(x:Double,y:Double)={        //随机梯度下降迭代公式
    θ=θ-α*((θ*x)-y)                 //迭代公式
  }

  def main(args: Array[String]): Unit = {
    val dataSource=getdata()          //获取数据集
    dataSource.foreach(myMap=>{       //开始迭代
      sgd(myMap._1,myMap._2)          //输入数据
    })
    println("最终结果值θ为:"+θ)
  }
}

需要注意的是随着步长系数增大以及数据量的增大,θ值偏差越来越大。同时这里也遗留下一个问题,当数据量大到一定程度,为什么θ值会为NaN,笔者心中有所疑惑,如果哪位大佬有想法可以留言探讨,谢谢!!!


如果你想和我一起学习交流,共同进步,欢迎加群:
在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/195274.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • idea2021 激活码【中文破解版】

    (idea2021 激活码)最近有小伙伴私信我,问我这边有没有免费的intellijIdea的激活码,然后我将全栈君台教程分享给他了。激活成功之后他一直表示感谢,哈哈~https://javaforall.cn/100143.htmlIntelliJ2021最新激活注册码,破解教程可免费永久激活,亲测有效,上面是详细链接哦~S32P…

  • Git安装配置教程

    Git安装配置教程1.Git简介Git是一个开源的分布式版本控制系统,可以有效、高速的处理从很小到非常大的项目版本管理1。Git是LinusTorvalds为了帮助管理Linux内核开发而开发的一个开放源码的版本控制软件。2.Git工作示意图3.Windows下安装Git3.1Git下载下载地址:https://git-for-windows.github.io/下载有时候很慢,请耐心

  • mybatis的逆向工程怎么实现_mybatis逆向工程原理

    mybatis的逆向工程怎么实现_mybatis逆向工程原理复习逆向工程的使用,记录方便以后参考mybatis,一个相对于hibernate的轻量级DAO框架,它的逆向工程可以很方便的从数据库到生成对应的entity和mapper接口。首先准备:准备pom.xml引入mybatisgenerator的jar,若不是maven工程,可以把jar下载下来导进工程的lib下即可01.引入依赖(加入jar)进入ma

  • constraint使用方法总结

    constraint使用方法总结

    2021年11月14日
  • rust-vmm 学习

    rust-vmm 学习V0.1.0featurebaseknowledge:ArchitectureoftheKernel-basedVirtualMachine(KVM)用rust-vmm打造未来的虚拟化架构KVM内核文档阅读笔记<MasteringKVMVirtualization>:第二章KVM内部原理UsingtheKVMAPI(org)…

  • 列一些Hbase面试题「建议收藏」

    列一些Hbase面试题「建议收藏」HbaseHbase是怎么写数据的?HDFS和HBase各自使用场景Hbase的存储结构热点现象(数据倾斜)怎么产生的,以及解决方法有哪些HBase的rowkey设计原则…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号