LSTM time series prediction in MATLAB



I recently took part in a small research project on time series prediction. I usually work in MATLAB, and online resources for this kind of task in MATLAB are relatively scarce.

The program below is adapted, with minor modifications, from http://blog.csdn.net/u010540396/article/details/52797489.

Program notes: DATA.mat holds a single time series (one vector of values, loaded as the variable a).

numdely is the number of previous points used to predict the current point, cell_num is the number of hidden memory cells, and cost_gate is the error threshold at which training stops.

To run it, type RunLstm(numdely,cell_num,cost_gate) at the command line.
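For example, the whole pipeline can be driven like this (a minimal sketch: the sine series, its length, and the parameter values 5, 5, 0.25 are illustrative choices rather than values from the original project; the variable name a is what LSTM_data_process expects inside DATA.mat):

a = sin(0.1*(1:200))' + 0.05*randn(200,1);   % toy series, one value per sample
save('DATA.mat','a');                        % RunLstm/LSTM_data_process read the series from DATA.mat
RunLstm(5,5,0.25);                           % predict each point from the previous 5, 5 memory cells, stop below cost 0.25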

function [r1, r2] = RunLstm(numdely,cell_num,cost_gate)
%% Load the data and normalize it
figure;
[train_data,test_data]=LSTM_data_process(numdely);
data_length=size(train_data,1)-1;
data_num=size(train_data,2);
%% Initialize network parameters
% node counts
input_num=data_length;
% cell_num=5;
output_num=1;
% gate biases
bias_input_gate=rand(1,cell_num);
bias_forget_gate=rand(1,cell_num);
bias_output_gate=rand(1,cell_num);
% initialize network weights
ab=20;
weight_input_x=rand(input_num,cell_num)/ab;
weight_input_h=rand(output_num,cell_num)/ab;
weight_inputgate_x=rand(input_num,cell_num)/ab;
weight_inputgate_c=rand(cell_num,cell_num)/ab;
weight_forgetgate_x=rand(input_num,cell_num)/ab;
weight_forgetgate_c=rand(cell_num,cell_num)/ab;
weight_outputgate_x=rand(input_num,cell_num)/ab;
weight_outputgate_c=rand(cell_num,cell_num)/ab;
% hidden-to-output weights
weight_preh_h=rand(cell_num,output_num);
% initialize network states
% cost_gate=0.25;
h_state=rand(output_num,data_num);
cell_state=rand(cell_num,data_num);
%% Train the network
for iter=1:100
    yita=0.01;            % learning rate for the weight updates
    for m=1:data_num
        % forward pass
        if(m==1)
            gate=tanh(train_data(1:input_num,m)'*weight_input_x);
            input_gate_input=train_data(1:input_num,m)'*weight_inputgate_x+bias_input_gate;
            output_gate_input=train_data(1:input_num,m)'*weight_outputgate_x+bias_output_gate;
            for n=1:cell_num
                input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
                output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
            end
            forget_gate=zeros(1,cell_num);
            forget_gate_input=zeros(1,cell_num);
            cell_state(:,m)=(input_gate.*gate)';
        else
            gate=tanh(train_data(1:input_num,m)'*weight_input_x+h_state(:,m-1)'*weight_input_h);
            input_gate_input=train_data(1:input_num,m)'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
            forget_gate_input=train_data(1:input_num,m)'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
            output_gate_input=train_data(1:input_num,m)'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
            for n=1:cell_num
                input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
                forget_gate(1,n)=1/(1+exp(-forget_gate_input(1,n)));
                output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
            end
            cell_state(:,m)=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';   
        end
        pre_h_state=tanh(cell_state(:,m)').*output_gate;
        h_state(:,m)=(pre_h_state*weight_preh_h)'; 
    end
    % compute the training error
%     Error=h_state(:,m)-train_data(end,m);
    Error=h_state(:,:)-train_data(end,:);
    Error_Cost(1,iter)=sum(Error.^2);
    if Error_Cost(1,iter) < cost_gate
        fprintf('Training stopped at iteration %d, cost %.4f\n',iter,Error_Cost(1,iter));
        break;
    end
                 [ weight_input_x,...
                weight_input_h,...
                weight_inputgate_x,...
                weight_inputgate_c,...
                weight_forgetgate_x,...
                weight_forgetgate_c,...
                weight_outputgate_x,...
                weight_outputgate_c,...
                weight_preh_h ]=LSTM_updata_weight(m,yita,Error,...
                                                   weight_input_x,...
                                                   weight_input_h,...
                                                   weight_inputgate_x,...
                                                   weight_inputgate_c,...
                                                   weight_forgetgate_x,...
                                                   weight_forgetgate_c,...
                                                   weight_outputgate_x,...
                                                   weight_outputgate_c,...
                                                   weight_preh_h,...
                                                   cell_state,h_state,...
                                                   input_gate,forget_gate,...
                                                   output_gate,gate,...
                                                   train_data,pre_h_state,...
                                                   input_gate_input,...
                                                   output_gate_input,...
                                                   forget_gate_input);


end
%% Plot the error-cost curve
for n=1:1:iter
    semilogy(n,Error_Cost(1,n),'*');
    hold on;
end
title('Error-Cost curve');
%% Test on the held-out sample
% load and normalize the test window
test_final=test_data;
test_final=test_final/sqrt(sum(test_final.^2));
total = sqrt(sum(test_data.^2));
test_output=test_data(:,end);
% forward pass for the test window
m=data_num;
gate=tanh(test_final(1:input_num)'*weight_input_x+h_state(:,m-1)'*weight_input_h);
input_gate_input=test_final(1:input_num)'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
forget_gate_input=test_final(1:input_num)'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
output_gate_input=test_final(1:input_num)'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
for n=1:cell_num
    input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
    forget_gate(1,n)=1/(1+exp(-forget_gate_input(1,n)));
    output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
end
cell_state_test=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
pre_h_state=tanh(cell_state_test').*output_gate;
h_state_test=(pre_h_state*weight_preh_h)'* total;
% report predicted vs. actual value for the held-out point
pred_str = sprintf('----Test result is %s----' ,num2str(h_state_test));
true_str = sprintf('----True result is %s----' ,num2str(test_output(end)));
disp(pred_str);
disp(true_str);
% return the predicted value and the actual value
r1 = h_state_test;
r2 = test_output(end);
end

function [train_data,test_data]=LSTM_data_process(numdely)

load('DATA.mat');          % loads the time series into the variable a
a = a(:);                  % make sure the series is a column vector
numdata = length(a);
numsample = numdata - numdely - 1;
train_data = zeros(numdely+1, numsample);
test_data = zeros(numdely+1,1);

for i = 1 :numsample
    train_data(:,i) = a(i:i+numdely)';
end

test_data = a(numdata-numdely: numdata);

data_length=size(train_data,1);          
data_num=size(train_data,2);           
%% normalize each training column to unit norm
for n=1:data_num
    train_data(:,n)=train_data(:,n)/sqrt(sum(train_data(:,n).^2));  
end
% (test_data is normalized later, inside RunLstm)
% for m=1:size(test_data,2)
%     test_data(:,m)=test_data(:,m)/sqrt(sum(test_data(:,m).^2));
% end
end
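A quick way to see what this helper produces is to run it on a tiny hand-made series (a throwaway check; the values and numdely=3 are purely illustrative, it temporarily overwrites DATA.mat, and it assumes LSTM_data_process is saved in its own .m file so it can be called directly):

a = (1:10)'; save('DATA.mat','a');   % toy series 1,2,...,10
[tr, te] = LSTM_data_process(3);     % windows of numdely = 3 past points plus the current point
size(tr)                             % 4 x 6: column i holds a(i:i+3), normalized to unit norm
te'                                  % 7 8 9 10: the last window, whose final value is the point to predict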

function [   weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h ]=LSTM_updata_weight(n,yita,Error,...
                                                   weight_input_x, weight_input_h, weight_inputgate_x,weight_inputgate_c,weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h,...
                                                   cell_state,h_state,input_gate,forget_gate,output_gate,gate,train_data,pre_h_state,input_gate_input, output_gate_input,forget_gate_input)

data_length=size(train_data,1) - 1;
data_num=size(train_data,2);
weight_preh_h_temp=weight_preh_h;


%%% weight update function
input_num=data_length;
cell_num=size(weight_preh_h_temp,1);
output_num=1;

%% update weight_preh_h
for m=1:output_num
    delta_weight_preh_h_temp(:,m)=2*Error(m,1)*pre_h_state;
end
weight_preh_h_temp=weight_preh_h_temp-yita*delta_weight_preh_h_temp;

%% update weight_outputgate_x
for num=1:output_num
    for m=1:data_length
        delta_weight_outputgate_x(m,:)=(2*weight_preh_h(:,num)*Error(num,1).*tanh(cell_state(:,n)))'.*exp(-output_gate_input).*(output_gate.^2)*train_data(m,n);
    end
    weight_outputgate_x=weight_outputgate_x-yita*delta_weight_outputgate_x;
end
%% update weight_inputgate_x
for num=1:output_num
for m=1:data_length
    delta_weight_inputgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*train_data(m,n);
end
weight_inputgate_x=weight_inputgate_x-yita*delta_weight_inputgate_x;
end


if(n~=1)
    %% update weight_input_x
    temp=train_data(1:input_num,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
    for m=1:data_length
        delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);   % d/dx tanh(x) = 1 - tanh(x).^2
    end
    weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
    %% update weight_forgetgate_x
    for num=1:output_num
    for m=1:data_length
        delta_weight_forgetgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*train_data(m,n);
    end
    weight_forgetgate_x=weight_forgetgate_x-yita*delta_weight_forgetgate_x;
    end
    %% update weight_inputgate_c
    for num=1:output_num
    for m=1:cell_num
        delta_weight_inputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*cell_state(m,n-1);
    end
    weight_inputgate_c=weight_inputgate_c-yita*delta_weight_inputgate_c;
    end
    %% update weight_forgetgate_c
    for num=1:output_num
    for m=1:cell_num
        delta_weight_forgetgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*cell_state(m,n-1);
    end
    weight_forgetgate_c=weight_forgetgate_c-yita*delta_weight_forgetgate_c;
    end
    %% update weight_outputgate_c
    for num=1:output_num
    for m=1:cell_num
        delta_weight_outputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*tanh(cell_state(:,n))'.*exp(-output_gate_input).*(output_gate.^2)*cell_state(m,n-1);
    end
    weight_outputgate_c=weight_outputgate_c-yita*delta_weight_outputgate_c;
    end
    %% update weight_input_h
    temp=train_data(1:input_num,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
    for m=1:output_num
        delta_weight_input_h(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*h_state(m,n-1);   % d/dx tanh(x) = 1 - tanh(x).^2
    end
    weight_input_h=weight_input_h-yita*delta_weight_input_h;
    end
else
    %% update weight_input_x
    temp=train_data(1:input_num,n)'*weight_input_x;
    for num=1:output_num
    for m=1:data_length
        delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);   % d/dx tanh(x) = 1 - tanh(x).^2
    end
    weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
end
weight_preh_h=weight_preh_h_temp;

end
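The gate deltas above rely on two derivative identities: for the sigmoid gates, d/dx sigmoid(x) = sigmoid(x)*(1-sigmoid(x)) = exp(-x)*sigmoid(x)^2, which is why each delta multiplies exp(-..._input) by the squared gate activation; and for the tanh nonlinearities, d/dx tanh(x) = 1 - tanh(x)^2. A throwaway numerical check of both identities (the variable names here are arbitrary):

x = linspace(-3,3,13);
sig = 1./(1+exp(-x));
max(abs(exp(-x).*sig.^2 - sig.*(1-sig)))                    % ~0: the two forms of the sigmoid derivative agree
h = 1e-6;
max(abs((tanh(x+h)-tanh(x-h))/(2*h) - (1-tanh(x).^2)))      % ~0: numerical vs. analytic tanh derivative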

—————————————2017.08.03 UPDATE—————————————-

Code and data download link:

http://download.csdn.net/detail/u011060119/9919621
