MCTS学习笔记「建议收藏」

MCTS学习笔记「建议收藏」MCTS树学习MCTS,即蒙特卡罗树搜索,是一类搜索算法树的统称,可以较为有效地解决一些搜索空间巨大的问题。如一个8*8的棋盘,第一步棋有64种着法,那么第二步则有63种,依次类推,假如我们把第一步棋作为根节点,那么其子节点就有63个,再往下的子节点就有62个……如果不加干预,树结构将会繁杂,MCTS采用策略来对获胜性较小的着法不予考虑,如第二步的63种着法中有10种是不可能胜利的,那么这十个…

大家好,又见面了,我是你们的朋友全栈君。

 

MCTS树学习

MCTS,即蒙特卡罗树搜索,是一类搜索算法树的统称,可以较为有效地解决一些搜索空间巨大的问题。

如一个8*8的棋盘,第一步棋有64种着法,那么第二步则有63种,依次类推,假如我们把第一步棋作为根节点,那么其子节点就有63个,再往下的子节点就有62个……

如果不加干预,树结构将会繁杂,MCTS采用策略来对获胜性较小的着法不予考虑,如第二步的63种着法中有10种是不可能胜利的,那么这十个子节点不予再次分配子节点。

MCTS的主要步骤分为四个:

1, 选择(Selection)

即找一个最好的值得探索的结点,通常是先选择没有探索过的结点,如果都探索过了,再选择UCB值最大的进行选择(UCB是由一系列算法计算得到的值,这里先不详细讲,可以简单视为value)

2, 扩展(Expansion)

已经选择好了需要进行扩展的结点,那么就对其进行扩展,即对其一个子节点最为下一步棋的假设,一般为随机取一个可选的节点进行扩展。

3, 模拟(Simulation)

扩展出了子节点,就可以根据该子节点继续进行模拟了,我们随机选择一个可选的位置作为模拟下一步的落子,将其作为子节点,然后依据该子节点,继续寻找可选的位置作为子节点,依次类推,直到博弈已经判断出了胜负,将胜负信息作为最终得分。

4, 回溯更新(Backpropagation)

将最终的得分累加到父节点,不断从下向上累加更新。

MCTS学习笔记「建议收藏」

 

对于UCB值,计算方法很简单,公式如下:

image

image

其中v’表示当前树节点,v表示父节点,Q表示这个树节点的累计quality值,N表示这个树节点的visit次数,C是一个 常量参数,通常值设为1/√2

接下来再讨论怎么使用Python实现MCTS树。

首先树的每个节点Node需要记录其父节点Node parent,和子节点Node children[],用于计算UCB的这个节点的quality值和visit次数。

    def __init__(self):
        self.parent = None
        self.children = []

        self.visit_times = 0
        self.quality_value = 0.0

        self.state = None

 

state中除了需要记录每一步的选择,还需要记录每一步的层数round值与reward值。
需要注意的是,在模拟的过程中,只有state状态的模拟和更新,更新后记录的是最终的reward状态,而树结构却没有随着模拟的进行而增加结点。
class State(object):
    def __init__(self):
        self.value = 0
        self.round = 0
        self.choices = []

整棵树需要实现的功能则是,在一个环境下,选择出一个最有可能获胜的策略。选择的方法则是通过以上介绍的四个步骤不停模拟得到每个选择的value。

其中,tree_policy函数实现了Selection和Expansion,default_poliy函数实现的是Simulation过程,backup函数是BackPropagation的实现。

def MCTS(node):

    computation_budget = 3

    for i in range(computation_budget):

        # 1\. 找到最合适的可扩展子节点 
        expand_node = tree_policy(node)

        # 2\. 随机选择下一步策略对此子节点进行模拟 
        reward = default_policy(expand_node)

        # 3\. 将模拟结果向上回传
        backup(expand_node, reward)

    # 最终得到胜利的可能性最大的子节点

     best_next_node = best_child(node, False)

     return best_next_node

tree_policy:选择最合适的子节点,选择策略如下:

1,如果当前的根节点是叶子节点,即没有子节点可以扩展,以开头下棋的例子来讲,即是已经判断出了胜负或者棋盘已满的情况下,则直接返回当前节点。

2,如果还有没有选择过的叶子节点(下一步的某个位置的着法还没有被模拟过),就在没有选择过的方法中选择一个返回。

3,如果所有可选择的结点都已经选择过(当前环境下所有的着法都已经试过),那么往下选择UCB值最大的子节点,直到满足1或2的情况,到达叶子节点或者出现未选择过的结点。

def tree_policy(node):

    # 是否是叶子节点
    while not node.get_state().is_terminal():

         # 如果全部可选的结点都选择过
         if node.is_all_expand():
             # 选择UCB最大的值
             node = best_child(node, True)

         else:

             # 随机选择一个节点返回
             sub_node = expand(node)
             return sub_node

    # 返回找到的最佳子节点
    return node

default_policy:对当前情况进行模拟,直到判断出胜负。

策略为:输入需要扩展的结点,随机操作后 创建新的结点,直到最后遇到叶子节点,得到该次模拟的reward,然后将reward返回。

def default_policy(node): 
        # 获取当前点的环境状态

        current_state = node.get_state() 

        # 如果没有遇到叶子节点,就一直循环
        while current_state.is_terminal() == False: 
                  # 随机选取一个子节点,返回新的环境参数 
                  current_state = current_state.get_next_state_with_random_choice()

        # 结束后,根据当前的环境判断胜负,即获得的reward值,并将其返回 
        final_state_reward = current_state.compute_reward()

        return final_state_reward

关于这个算法,我简单做了一个实现,每次从数组[1, -1, 2, -2]之间随机取一个数做累加,共累计MAX_DEPTH层,使最终的和最大,我们根据运行结果可以看到,开始-1, -2的概率比较大,但是随着训练层数的增大,越来越小,而1,2的比例会越来越大。

import sys
import math
import random
MAX_CHOICE = 4
MAX_DEPTH = 50
CHOICES = [1, -1, 2, -2]
class State(object):
def __init__(self):
self.value = 0
self.round = 0
self.choices = []
def new_state(self):
choice = random.choice(CHOICES)
state = State()
state.value = self.value + choice
state.round = self.round + 1
state.choices = self.choices + [choice]
return state
def __repr__(self):
return "State: {}, value: {}, choices: {}".format(
hash(self), self.value, self.choices)
class Node(object):
def __init__(self):
self.parent = None
self.children = []
self.quality = 0.0
self.visit = 0
self.state = None
def add_child(self, node):
self.children.append(node)
node.parent = self
def __repr__(self):
return "Node: {}, Q/N: {}/{}, state: {}".format(
hash(self), self.quality, self.visit, self.state)
def expand(node):
states = [nodes.state for nodes in node.children]
state = node.state.new_state()
while state in states:
state = node.state.new_state()
child_node = Node()
child_node.state = state
node.add_child(child_node)
return child_node
# 选择, 扩展
def tree_policy(node):
# 选择是否是叶子节点,
while node.state.round < MAX_DEPTH:
if len(node.children) < MAX_CHOICE:
node = expand(node)
return node
else:
node = best_child(node)
return node
# 模拟
def default_policy(node):
now_state = node.state
while now_state.round < MAX_DEPTH:
now_state = now_state.new_state()
return now_state.value
def backup(node, reward):
while node != None:
node.visit += 1
node.quality += reward
node = node.parent
def best_child(node):
best_score = -sys.maxsize
best = None
for sub_node in node.children:
C = 1 / math.sqrt(2.0)
left = sub_node.quality / sub_node.visit
right = 2.0 * math.log(node.visit) / sub_node.visit
score = left + C * math.sqrt(right)
if score > best_score:
best = sub_node
best_score = score
return best
def mcts(node):
times = 5
for i in range(times):
expand = tree_policy(node)
reward = default_policy(expand)
backup(expand, reward)
best = best_child(node)
return best
def main():
init_state = State()
init_node = Node()
init_node.state = init_state
current_node = init_node
for i in range(MAX_DEPTH):
a = 0.0
b = 0.0
c = 0.0
d = 0.0
current_node = mcts(current_node)
for j in range(len(current_node.state.choices)):
if current_node.state.choices[j] == -2:
a += 1
if current_node.state.choices[j] == -1:
b += 1
if current_node.state.choices[j] == 1:
c += 1
if current_node.state.choices[j] == 2:
d += 1
print("-2的概率为", round(a/(i + 1.0), 2),
"-1的概率为", round(b/(i + 1.0), 2),
"1的概率为", round(c/(i + 1.0), 2),
"2的概率为", round(d/(i + 1.0), 2))
if __name__ == "__main__":
main()

运行结果:

 

MCTS学习笔记「建议收藏」

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/148674.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • jadxgui反编译教程_apktool工具反编译apk

    jadxgui反编译教程_apktool工具反编译apk可以直接在GitHub上:https://github.com/skylot/jadx.git找到反编译工具jadx-gui源码,在windows电脑:(电脑上已经有git命令工具)gitclonehttps://github.com/skylot/jadx.git然后打开cmd命令窗口:进入到gitclone下来的文件所在的文件路径下,cdE:\jadx之后运行:gra…

    2022年10月25日
  • 贝叶斯分类器[通俗易懂]

    贝叶斯分类器[通俗易懂]实验名称:贝叶斯分类器一、实验目的和要求目的:掌握利用贝叶斯公式进行设计分类器的方法。要求:分别做出协方差相同和不同两种情况下的判别分类边界。二、实验环境、内容和方法环境:windows7,m

  • ActiveXObject 安装

    ActiveXObject 安装将后缀名为ocx的文件拷贝至目录c:\Windows\SysWOW64\。执行如下命令,进行注册:regsvr32c:\Windows\SysWOW64\x.ocx转载于:https://www.cnblogs.com/Currention/p/11024354.html

    2022年10月14日
  • rstudio安装后打不开_r语言和rstudio的安装

    rstudio安装后打不开_r语言和rstudio的安装1、r语言的下载地址TheComprehensiveRArchiveNetwork​cran.r-project.org2、安装:按着提醒直接下一步,路径不用改,默认的就可以,但是不可以出现中文,64位系统就全选,32位系统不能选64位系统,一个是32位一个是64位。3、下载RStudio安装包RStudio​www.rstudio.com免费的4、安装rstudio跟着提醒点击下一步,安…

  • java冒泡排序代码_Java冒泡排序

    java冒泡排序代码_Java冒泡排序一、冒泡排序:利用冒泡排序对数组进行排序二、基本概念:依次比较相邻的两个数,将小数放在前面,大数放在后面。即在第一趟:首先比较第1个和第2个数,将小数放前,大数放后。然后比较第2个数和第3个数,将小数放前,大数放后,如此继续,直至比较最后两个数,将小数放前,大数放后。至此第一趟结束,将最大的数放到了最后。在第二趟:仍从第一对数开始比较(因为可能由于第2个数和第3个数的交换,使得第1个数不再小于第2…

  • Tomcat日志切割总结[通俗易懂]

    Tomcat日志切割总结[通俗易懂]目录目录前言1.创建shell脚本进行catalina.out日志文件切割2.使用log4j成功使catalina.out文件实现分割3.用cronolog软件来分割Tomcat的catalina.out文件假设我们想日志以catalina.2018-08-31.out这种方式分割前言我们都知道将一个项目部署到Tomcat之后,Tomcat服…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号