87. 实现 Sarsa 学习算法走出迷宫#

87.1. 介绍#

在基于价值的强化学习中我们主要实现了 Q-Learning 算法,事实上 Sarsa 算法和 Q-Learning 最大的区别就在于 Q-Table 的更新,本次实验,结合实验中 Q-Learning 的算法实现,并根据 Sarsa 的算法流程来完成迷宫挑战。

87.2. 知识点#

  • Q-Table 初始化

  • Q-Table 更新函数

  • Sarsa 完整算法实现

87.3. Q-Table 初始化#

根据前面的实验内容,你应该知道不论是 Q-Learning 还是 Sarsa,其核心都是基于价值迭代,所以需要先初始化 Q-Table。

Exercise 87.1

挑战:按要求初始化 Q-Table。

规定:构造一个 \(16*4\) 的 DataFrame 表(16 个 state,4 个 action)作为 Q-Table。

提示:和实验中 Q-Learning 初始化方式相同。

import numpy as np
import pandas as pd
import time
from IPython import display


def init_q_table():
    ### 代码开始 ### (≈ 2 行代码)
    actions = None
    q_table = None
    ### 代码结束 ###

    return q_table
init_q_table()

Note

本课程中,Notebook 挑战系统无法自动评判,你需要自行补充上方单元格中缺失的代码并运行,如果输出结果和下方的期望输出结果一致,即代表此挑战顺利通过。完成全部内容后,点击「提交检测」即可通过,此说明后续不再出现。

期望输出

up down left right
0 0.0 0.0 0.0 0.0
1 0.0 0.0 0.0 0.0
2 0.0 0.0 0.0 0.0
3 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0
5 0.0 0.0 0.0 0.0
6 0.0 0.0 0.0 0.0
7 0.0 0.0 0.0 0.0
8 0.0 0.0 0.0 0.0
9 0.0 0.0 0.0 0.0
10 0.0 0.0 0.0 0.0
11 0.0 0.0 0.0 0.0
12 0.0 0.0 0.0 0.0
13 0.0 0.0 0.0 0.0
14 0.0 0.0 0.0 0.0
15 0.0 0.0 0.0 0.0

87.4. 动作选择#

接下来,我们需要使用 \(\epsilon-greedy\) 方法根据 Q-Table 进行动作选择,这里仿照实验内容实现 act_choose 函数。

Exercise 87.2

挑战:使用 \(\epsilon-greedy\) 方法根据 Q-Table 进行动作选择。

规定:在概率为 \(1-epsilon\) ,或 Q 值都为 0 的情况下,随机选择动作;此外,按照 Q 的最大值选择动作,并且动作用 action 表示。

提示:这里可能会使用 if,else 语句判断,与实验中内容相同。

def act_choose(state, q_table, epsilon):
    state_act = q_table.iloc[state, :]
    actions = np.array(["up", "down", "left", "right"])

    ### 代码开始 ### (≈ 4 行代码)
    if None:
        action = None
    else:
        action = None
    ### 代码结束 ###

    return action

运行测试

seed = np.random.RandomState(25)  # 为了保证验证结果相同引入随机数种子
a = seed.rand(16, 4)
test_q_table = pd.DataFrame(a, columns=["up", "down", "left", "right"])
l = []
for s in [1, 4, 7, 12, 14]:
    l.append(act_choose(state=s, q_table=test_q_table, epsilon=1))
l

期望输出

['left', 'right', 'right', 'right', 'left']

87.5. 行为反馈#

在行为反馈中我们同样将 terminal 终点的奖励设为 10,将 hole 陷阱的惩罚设为 -10,同样为了尽快找到最短路径,每一步的惩罚为 -1。直接沿用实验中相似代码块即可。

def env_feedback(state, action, hole, terminal):
    reward = 0.0
    end = 0
    a, b = state
    if action == "up":
        a -= 1
        if a < 0:
            a = 0
        next_state = (a, b)
    elif action == "down":
        a += 1
        if a >= 4:
            a = 3
        next_state = (a, b)
    elif action == "left":
        b -= 1
        if b < 0:
            b = 0
        next_state = (a, b)
    elif action == "right":
        b += 1
        if b >= 4:
            b = 3
        next_state = (a, b)

    if next_state == terminal:
        reward = 10.0
        end = 2
    elif next_state == hole:
        reward = -10.0
        end = 1
    else:
        reward = -1.0

    return next_state, reward, end

87.6. Q-Table 更新#

接下来,就需要完成 Q-Table 更新函数。通过实验内容可知,Sarsa 的 Q-Table 的更新是与 Q-Learning 最大的区别之处,所以需要根据 Sarsa 的 Q-Table 更新公式来实现。

Exercise 87.3

挑战:根据下方 Sarsa 的 Q-Table 的更新公式完善 Q-Table 更新函数。

提示:结合 Q-Learning 中 Q-Table 更新函数进行修改,通过标签查看 DataFrame 特定值时使用 .loc[]

\[ Q(s_{t},a_{t})=(1-\alpha) \cdot Q(s_{t},a_{t})+\alpha \cdot (r_{t}+\gamma \cdot Q(s_{t+1},a_{t+1})) \]
def update_q_table(
    q_table, state, action, next_state, next_action, terminal, gamma, alpha, reward
):
    x, y = state
    next_x, next_y = next_state
    q_original = q_table.loc[x * 4 + y, action]

    if next_state != terminal:
        ### 代码开始 ### (≈ 1 行代码)
        q_predict = None
        ### 代码结束 ###
    else:
        q_predict = reward

    ### 代码开始 ### (≈ 1 行代码)
    q_table.loc[None] = None
    ### 代码结束 ###

    return q_table

运行测试(仅执行一次,重复执行请重启 kernel)

new_q_table = update_q_table(
    q_table=test_q_table,
    state=(2, 2),
    action="right",
    next_state=(2, 3),
    next_action="down",
    terminal=(3, 2),
    gamma=0.9,
    alpha=0.8,
    reward=10,
)

new_q_table.loc[10, "right"]

期望输出:(仅执行一次得到的结果)

8.740755431411795

同样为了展示强化学习效果,定义一个状态展示函数,此处综合沿用实验中相应代码块即可。

def show_state(end, state, episode, step, q_table):
    terminal = (3, 2)
    hole = (2, 1)
    env = np.array([["_ "] * 4] * 4)
    env[terminal] = "$ "
    env[hole] = "# "
    env[state] = "L "
    interaction = ""
    for i in env:
        interaction += "".join(i) + "\n"

    if state == terminal:
        message = "EPISODE: {}, STEP: {}".format(episode, step)
        interaction += message
        display.clear_output(wait=True)
        print(interaction)
        print("\n" + "q_table:")
        print(q_table)
        time.sleep(3)  # 在成功到终点时,等待 3 秒
    else:
        display.clear_output(wait=True)
        print(interaction)
        print("\n" + "q_table:")
        print(q_table)
        time.sleep(0.3)  # 在这里控制每走一步所需要时间

87.7. Sarsa 算法实现#

最后,我们根据 Sarsa 算法伪代码来实现完整的学习过程。

Exercise 87.4

挑战:顺利完成以上几个函数后,根据下方 Sarsa 算法伪代码实现完整的学习过程。请结合 Q-Learning 完成代码。

image

def sarsa(max_episodes, alpha, gamma, epsilon):
    q_table = init_q_table()
    terminal = (3, 2)
    hole = (2, 1)
    episodes = 0
    while episodes < max_episodes:
        step = 0
        state = (0, 0)
        end = 0
        show_state(end, state, episodes, step, q_table)
        x, y = state

        ### 代码开始 ### (≈ 1 行代码)
        action = None  # 动作选择
        ### 代码结束 ###

        while end == 0:
            next_state, reward, end = env_feedback(
                state, action, hole, terminal
            )  # 环境反馈
            next_x, next_y = next_state
            next_action = act_choose(next_x * 4 + next_y, q_table, epsilon)  # 动作选择

            ### 代码开始 ### (≈ 3 行代码)
            q_table = None  # q-table 更新
            state = None
            action = None
            ### 代码结束 ###

            step += 1
            show_state(end, state, episodes, step, q_table)
        if end == 2:
            episodes += 1
sarsa(max_episodes=20, alpha=0.8, gamma=0.9, epsilon=0.9)  # 执行测试

期望输出

_ _ _ _ 
_ _ _ _ 
_ # _ _ 
_ _ L _ 
EPISODE: 19, STEP: 5

q_table:
          up       down      left     right
0  -4.421534  -3.457078 -3.936450 -4.152483
1  -3.409185  -9.062400 -3.433181 -3.596441
2  -2.213120   4.590499 -3.514029 -3.414613
3  -1.536000  -1.574400 -2.908418 -3.936450
4  -4.109114  -2.730086 -2.836070 -2.548000
5  -2.836070  -9.984000 -2.065920 -1.720000
6  -2.850867   8.000000 -2.562662 -0.800000
7  -2.342144  -0.800000 -1.982720  0.000000
8  -2.509531  -2.348544 -2.213120 -8.000000
9   0.000000   0.000000  0.000000  0.000000
10 -3.033926  10.000000 -8.000000 -0.800000
11 -2.844488   0.000000  0.000000 -0.800000
12 -2.766862  -1.536000 -1.536000  6.142464
13  0.000000   0.000000 -2.325504  9.600000
14  0.000000   0.000000  0.000000  0.000000
15  0.000000   0.000000  0.000000  0.000000

由于 Q Table 的值是随机的,上面的实验结果仅供参考。只要随着迭代次数的增加,Q Table 按要求持续更新,并使得智能体走的步数变少,最终接近 5 步即可。