基于LSTM的DDPG实现

基于LSTM的DDPG实现最近看了一些大佬的DDPG的实现(其实都是基于莫凡大佬的那个版本),结合我自己的毕设问题,发现只是用普通的全连接网络好像不太稳定,表现也不好,于是尝试了一下试着用一直对序列数据有强大处理能力的lstm来试试(虽然这个已经有人做过了),自己手动实现了一下基于lstm的ddpg,希望各位大佬指导指导。importtorchimporttorch.nnasnnimporttorch.op…

大家好,又见面了,我是你们的朋友全栈君。

这两天实在不想动这个东西,想了想还是毕业要紧。
稍微跟自己搭的环境结合了一下,对于高维的状态输入可以完成训练(但效果没测试,至少跑通了),并且加入了batch训练的过程,根据伯克利课程说明,加入batch的话会让训练方差减小,提升系统的稳定性。但是因为memory那块使用list做的所以取batch的时候过程相当绕(我发现我现在写python代码还是摆脱不了java的影子啊),希望有大佬给我点建议。

最近看了一些大佬的DDPG的实现(其实都是基于莫凡大佬的那个版本),结合我自己的毕设问题,发现只是用普通的全连接网络好像不太稳定,表现也不好,于是尝试了一下试着用一直对序列数据有强大处理能力的lstm来试试(虽然这个已经有人做过了),自己手动实现了一下基于lstm的ddpg,希望各位大佬指导指导。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from Env_2_DDPG import Environment

date_count = 5
date_dim = 6
hide_dim = 10
hide_dim_lstm = 100
gamma = 0.8
lr_miu = 0.01
lr_Q = 0.02
tau = 0.01
trans_num = 10
batch_size = 4
MAX_EPISODES = 10
MAX_EP_STEPS = 500
memory_size = 10


class My_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.mean(-x)


class A_net(nn.Module):
    def __init__(self):
        super(A_net, self).__init__()
        self.state_dim = date_dim*date_count
        self.net = nn.LSTM(date_dim, hide_dim_lstm)
        self.reg = nn.Linear(date_count*hide_dim_lstm, 1)

    def forward(self, state):
        state = torch.Tensor(state)
        # state = torch.unsqueeze(state, 0)
        x, _ = self.net(state)
        s, b, h = x.shape
        x = x.view(s, b*h)
        x = self.reg(x)
        x = x.view(s, -1)
        return x


class C_net(nn.Module):
    def __init__(self):
        super(C_net, self).__init__()
        self.state_dim = date_count*date_dim
        self.net = nn.LSTM(date_dim, hide_dim_lstm)
        self.lstm_res = nn.Linear(date_count*hide_dim_lstm, date_count)
        self.action_net = nn.Linear(1, date_count)
        self.reg = nn.Linear(date_count, 1)

    def forward(self, state, action):
        state = torch.Tensor(state)
        x1, _ = self.net(state)
        x2 = self.action_net(action)
        s, b, h = x1.shape
        x1 = x1.view(s, b*h)
        x1 = self.lstm_res(x1)
        x = self.reg(x1+x2)
        x = x.view(s, -1)
        return x


class ddpg_lstm(nn.Module):
    def __init__(self):
        super(ddpg_lstm, self).__init__()
        self.miu_net = A_net()
        self.miu_pie = A_net()
        self.Q_net = C_net()
        self.Q_pie = C_net()
        self.optim_miu = optim.SGD(self.miu_net.parameters(), lr=lr_miu, momentum=0.5)
        self.optim_Q = optim.Adam(self.Q_net.parameters(), lr=lr_Q)
        self.loss_Q = nn.MSELoss()
        self.memory = list()
        self.index = 0

    def learn(self, tra):
        s = torch.Tensor(tra[0])
        s = s.reshape(batch_size, date_count, date_dim)
        r = torch.Tensor(tra[1])
        a = torch.Tensor(tra[2])
        s_ = torch.Tensor(tra[3])
        s_ = s_.reshape(batch_size, date_count, date_dim)
        a_ = self.miu_pie(s_)
        y = r + gamma*self.Q_pie(s_, a_)
        # a = torch.Tensor(np.array([a]))
        q = self.Q_net(s, a)

        self.optim_Q.zero_grad()
        q_loss = self.loss_Q(y, q)
        q_loss.backward(retain_graph=True)
        self.optim_Q.step()

        self.optim_miu.zero_grad()
        _miu_loss = My_loss()
        miu_loss = _miu_loss(q)
        miu_loss.backward()
        self.optim_miu.step()

    def soft_update(self):
        self.miu_pie.net.weight.data = tau*self.miu_net.net.weight.data + (1-tau)*self.miu_pie.net.weight.data
        self.miu_pie.reg.weight.data = tau*self.miu_net.reg.weight.data + (1-tau)*self.miu_pie.reg.weight.data
        self.Q_pie.net.weight.data = tau*self.Q_net.net.weight.data + (1-tau)*self.Q_pie.net.weight.data
        self.Q_pie.action_net.weight.data = tau*self.Q_net.action_net.weight.data + (1-tau)*self.Q_pie.action_net.weight.data
        self.Q_pie.reg.weight.data = tau*self.Q_net.reg.weight.data + (1-tau)*self.Q_pie.reg.weight.data

        # for x in self.miu_net.state_dict().keys():
        # eval('self.miu_net.' + x + '.data.mul_((1-TAU))')
        # eval('self.miu_net.' + x + '.data.add_(TAU*self.Actor_eval.' + x + '.data)')
        # for x in self.Critic_target.state_dict().keys():
        # eval('self.Critic_target.' + x + '.data.mul_((1-TAU))')
        # eval('self.Critic_target.' + x + '.data.add_(TAU*self.Critic_eval.' + x + '.data)')

    def store_trans(self, s, r, a, s_):
        temp = list()
        temp.append(s)
        temp.append(r)
        temp.append(a)
        temp.append(s_)
        self.memory.append(temp)
        self.index += 1
        if self.index > trans_num:
            del self.memory[0]

    def train_model(self):
        batch = np.random.choice(memory_size, batch_size)
        bs = np.zeros((batch_size, date_count*date_dim))
        br = np.zeros((batch_size, 1))
        ba = np.zeros((batch_size, 1))
        bs_ = np.zeros((batch_size, date_count*date_dim))
        index_ = 0
        for item in batch:
            bs[index_] = (np.array(self.memory[item][0])).reshape(date_dim*date_count)
            br[index_] = self.memory[item][1]
            ba[index_] = self.memory[item][2]
            bs_[index_] = (np.array(self.memory[item][3])).reshape(date_dim*date_count)
            index_ += 1
            # self.learn(self.memory[item])
        self.learn([bs, br, ba, bs_])
        print('over')
        # self.learn(tra)
            # self.learn()

    def next_action(self, state):
        action = self.miu_net(state)
        return action.detach()


ddpg = ddpg_lstm()
env = Environment()


def train():
    for i in range(MAX_EPISODES):
        # if i > 50:
        # print(i)
        s = env.reset()
        ep_reward = 0
        for j in range(MAX_EP_STEPS):
            a = ddpg.next_action(s)
            s_, r = env.step(a)
            ddpg.store_trans(s, a, r/10, s_)
            if ddpg.index > memory_size:
                ddpg.train_model()
            s = s_
            ep_reward += r
        if i % 1 == 0 and i > 0:
            print('Episode:', i, ' Reward: %i' % int(ep_reward))
    torch.save(ddpg, 'ddpg2.pt')





需要注意的是我这个没有对数据进行处理,主要针对的是单个数据,还没有针对batch数据,因此在数据送入lstm模型之前手动加了个torch.unsqueeze()强行扩展一个维度。
目前程序处在能跑通的阶段,后续有时间的话继续更新吧。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/151665.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 引起cpu流水线阻塞的三个原因

    引起cpu流水线阻塞的三个原因1、多个任务在同一时间周期内争用同一个流水段(资源冲突)例如,假如在指令流水线中,如果数据和指令是放在同一个储存器中,并且访问接口也只有一个,那么,两条指令就会争用储存器;在一些算数流水线中,有些运算会同时访问一个运算部件。2、数据依赖(数据相关)比如,A运算必须得到B运算的结果,但是,B运算还没有开始,A运算动作就必须等待,直到A运算完成,两次运算不能同时执行。3、 条件转移的影响(条件转移)如…

  • invocationHandler_handlermapping原理

    invocationHandler_handlermapping原理动态代理是很多框架和技术的基础,spring的AOP实现就是基于动态代理实现的。了解动态代理的机制对于理解AOP的底层实现是很有帮助的。      查看doc文档就可以知道,在java.lang.reflect包中有一个叫Proxy的类。下面是doc文档对Proxy类的说明:      “Adynamicproxyclass(simplyreferredtoasa

    2022年10月28日
  • java环境变量 的配置与详解(全网最详细教程)

    java环境变量 的配置与详解(全网最详细教程)笔者这学期开始学习java课程,学习java开发首先需要配置java运行环境变量。虽然上课老师也讲了如何配置java环境变量,可是笔者的同学还是有好多都不会配置,所以笔者最近配置了特别多次java环境变量。如下笔者详细解释从JDK安装到环境变量的装配。目录 JDK的下载与安装 配置java环境变量JAVA_HOME变量Path变量ClassPath变量classpath…

  • vue devtools如何使用调试_千牛提示opendevtools

    vue devtools如何使用调试_千牛提示opendevtoolsWriteByMonkeyfly以下内容均为原创,如需转载请注明出处。前提今天准备开始学vue.js了,不为别的,只是因为我女朋友毕设项目的前端是使用vue开发的,而我作为一个前端开发却无能为力,你说可不可笑。她需要一个会vue的前端帮她做界面,而我虽然身为一个前端开发,但是并不会vue,所以作为男朋友的我本身就很自责。现阶段的情况是:我只是知道有这些框架,再加上公司的项…

  • 手机来电通核心模块——归属地数据库设计(Winsym原创)「建议收藏」

    手机来电通核心模块——归属地数据库设计(Winsym原创)「建议收藏」说到Symbian,确实让人头痛。不仅开发平台和SDK版本众多,难以选择,而且对程序员确实要求很高,光是SymbianC++的熟悉就要花上很长时间,更麻烦的是测试和调试。模拟器只能提供一部分功能,和电话通信有关的全部要在真机上测试。很多时候,在模拟器上能跑的代码,放到真机上就不行了,这其中的心酸想必开发过得朋友深有体会。小弟我因为工程实践项目的要求,和几位嵌入式的高手一起搞了Symbian来电通项目。其实来电通项目已经有很多人做了,比较有名的是CallMaster和柳丁,但是这方面的关键技术和源码至今没有

  • thinkPHP的优缺点

    thinkPHP的优缺点

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号