强化学习笔记 (2)- 从 Q-Learning 到 DQN
在上一篇文章强化学习笔记 (1)- 概述中,介绍了通过 MDP 对强化学习的问题进行建模,但是由于强化学习往往不能获取 MDP 中的转移概率,解决 MDP 的 value iteration 和 policy iteration 不能直接应用到解决强化学习的问题上,因此出现了一些近似的算法来解决这个问题,本文要介绍的就是基于 value iteration 而发展出来的 Q-Learning 系列方法,包括 Q-Learning、Sarsa 和 DQN。
Q-Learning
Q-Learning 是一个强化学习中一个很经典的算法,其出发点很简单,就是用一张表存储在各个状态下执行各种动作能够带来的 reward,如下表表示了有两个状态 \(s_1, s_2\),每个状态下有两个动作 \(a_1, a_2\), 表格里面的值表示 reward
- | a1 | a2 |
---|---|---|
s1 | -1 | 2 |
s2 | -5 | 2 |
这个表示实际上就叫做 Q-Table,里面的每个值定义为 \(Q(s, a)\), 表示在状态 \(s\) 下执行动作 \(a\) 所获取的 reward,那么选择的时候可以采用一个贪婪的做法,即选择价值最大的那个动作去执行。
这样问题就来了,就是 Q-Table 要如何获取?答案是随机初始化,然后通过不断执行动作获取环境的反馈并通过算法更新 Q-Table。下面重点讲如何通过算法更新 Q-Table。
当我们处于某个状态 \(s\) 时,根据 Q-Table 的值选择的动作 \(a\), 那么从表格获取的 reward 为 \(Q(s,a)\),此时的 reward 并不是我们真正的获取的 reward,而是预期获取的 reward,那么真正的 reward 在哪?我们知道执行了动作 \(a\) 并转移到了下一个状态 \(s'\) 时,能够获取一个即时的 reward(记为 \(r\)), 但是除了即时的 reward,还要考虑所转移到的状态 \(s'\) 对未来期望的 reward,因此真实的 reward (记为 \(Q'(s,a)\)) 由两部分组成:即时的 reward 和未来期望的 reward,且未来的 reward 往往是不确定的,因此需要加个折扣因子 \(\gamma\), 则真实的 reward 表示如下
\[\begin{align} Q'(s,a) = r + \gamma\max_{a'}Q(s',a') \end{align}\]
\(\gamma\) 的值一般设置为 0 到 1 之间,设为 0 时表示只关心即时回报,设为 1 时表示未来的期望回报跟即时回报一样重要。
有了真实的 reward 和预期获取的 reward,可以很自然地想到用 supervised learning 那一套,求两者的误差然后进行更新,在 Q-learning 中也是这么干的,更新的值则是原来的 Q (s, a),更新规则如下
\[\begin{align} Q(s, a) = Q(s, a) + \alpha(Q'(s, a) - Q(s,a)) \end{align}\]
更新规则跟梯度下降非常相似,这里的 \(\alpha\) 可理解为学习率。
更新规则也很简单,可是这一类采用了贪心思想的算法往往都会有这么一个问题:算法是否能够收敛,是收敛到局部最优还是全局最优?
关于收敛性,可以参考 Convergence of Q-learning: a simple proof,这个文档 证明了这个算法能够收敛,且根据知乎上这个问题 RL 两大类算法的本质区别?(Policy Gradient 和 Q-Learning),原始的 Q-Learning 理论上能够收敛到最优解,但是通过 Q 函数近似 Q-Table 的方法则未必能够收敛到最优解(如 DQN)。
除此之外, Q-Learning 中还存在着探索与利用 (Exploration and Exploition) 的问题, 大致的意思就是不要每次都遵循着当前看起来是最好的方案,而是会选择一些当前看起来不是最优的策略,这样也许会更快探索出更优的策略。
Exploration and Exploition 的做法很多,Q-Learning 采用了最简单的 \(\epsilon\)-greedy, 就是每次有 \(\epsilon\) 的概率是选择当前 Q-Table 里面值最大的 action 的,1 - \(\epsilon\) 的概率是随机选择策略的。
Q-Learning 算法的流程如下,图片摘自这里
上面的流程中的 Q 现实 就是上面说的 \(Q'(s,a)\), Q 估计就是上面说的 \(Q(s,a)\)。
下面的 python 代码演示了更新通过 Q-Table 的算法,参考了这个 repo 上的代码,初始化主要是设定一些参数,并建立 Q-Table, choose_action
是根据当前的状态 observation
,并以 \(\epsilon\)-greedy 的策略选择当前的动作; learn
则是更新当前的 Q-Table,check_state_exist
则是检查当前的状态是否已经存在 Q-Table 中,若不存在要在 Q-Table 中创建相应的行。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43import numpy as np
import pandas as pd
class QTable:
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
self.actions = actions # a list
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
def choose_action(self, observation):
self.check_state_exist(observation)
# action selection
if np.random.uniform() < self.epsilon:
# choose best action
state_action = self.q_table.ix[observation, :]
state_action = state_action.reindex(np.random.permutation(state_action.index)) # some actions have same value
action = state_action.argmax()
else:
# choose random action
action = np.random.choice(self.actions)
return action
def learn(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.ix[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.ix[s_, :].max() # next state is not terminal
else:
q_target = r # next state is terminal
self.q_table.ix[s, a] += self.lr * (q_target - q_predict) # update
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
self.q_table = self.q_table.append(
pd.Series(
[0]*len(self.actions),
index=self.q_table.columns,
name=state,
)
)
Sarsa
Sarsa 跟 Q-Learning 非常相似,也是基于 Q-Table 进行决策的。不同点在于决定下一状态所执行的动作的策略,Q-Learning 在当前状态更新 Q-Table 时会用到下一状态 Q 值最大的那个动作,但是下一状态未必就会选择那个动作;但是 Sarsa 会在当前状态先决定下一状态要执行的动作,并且用下一状态要执行的动作的 Q 值来更新当前状态的 Q 值;说的好像很绕,但是看一下下面的流程便可知道这两者的具体差异了,图片摘自这里
那么,这两者的区别在哪里呢?这篇文章里面是这样讲的
This means that SARSA takes into account the control policy by which the agent is moving, and incorporates that into its update of action values, where Q-learning simply assumes that an optimal policy is being followed.
简单来说就是 Sarsa 在执行 action 时会考虑到全局(如更新当前的 Q 值时会先确定下一步要走的动作), 而 Q-Learning 则显得更加的贪婪和 "短视", 每次都会选择当前利益最大的动作 (不考虑 \(\epsilon\)-greedy),而不考虑其他状态。
那么该如何选择,根据这个问题:When to choose SARSA vs. Q Learning,有如下结论
If your goal is to train an optimal agent in simulation, or in a low-cost and fast-iterating environment, then Q-learning is a good choice, due to the first point (learning optimal policy directly). If your agent learns online, and you care about rewards gained whilst learning, then SARSA may be a better choice.
简单来说就是如果要在线学习,同时兼顾 reward 和总体的策略 (如不能太激进,agent 不能很快挂掉),那么选择 Sarsa;而如果没有在线的需求的话,可以通过 Q-Learning 线下模拟找到最好的 agent。所以也称 Sarsa 为 on-policy,Q-Leanring 为 off-policy。
DQN
我们前面提到的两种方法都以依赖于 Q-Table,但是其中存在的一个问题就是当 Q-Table 中的状态比较多,可能会导致整个 Q-Table 无法装下内存。因此,DQN 被提了出来,DQN 全称是 Deep Q Network,Deep 指的是通的是深度学习,其实就是通过神经网络来拟合整张 Q-Table。
DQN 能够解决状态无限,动作有限的问题;具体来说就是将当前状态作为输入,输出的是各个动作的 Q 值。以 Flappy Bird 这个游戏为例,输入的状态近乎是无限的(当前 bird 的位置和周围的水管的分布位置等),但是输出的动作只有两个 (飞或者不飞)。实际上,已经有人通过 DQN 来玩这个游戏了,具体可参考这个 DeepLearningFlappyBird
所以在 DQN 中的核心问题在于如何训练整个神经网络,其实训练算法跟 Q-Learning 的训练算法非常相似,需要利用 Q 估计和 Q 现实的差值,然后进行反向传播。
这里放上提出 DQN 的原始论文 Playing atari with deep reinforcement learning 中的算法流程图
上面的算法跟 Q-Learning 最大的不同就是多了 Experience Replay 这个部分,实际上这个机制做的事情就是先进行反复的实验,并将这些实验步骤获取的 sample 存储在 memory 中,每一步就是一个 sample,每个 sample 是一个四元组,包括:当前的状态,当前状态的各种 action 的 Q 值,当前采取的 action 获得的即时回报,下一个状态的各种 action 的 Q 值。拿到这样一个 sample 后,就可以根据上面提到的 Q-Learning 更新算法来更新网络,只是这时候需要进行的是反向传播。
Experience Replay 机制的出发点是按照时间顺序所构造的样本之间是有关的 (如上面的 \(\phi(s_{t+1})\) 会受到 \(\phi(s_{t})\) 的影响)、非静态的(highly correlated and non-stationary),这样会很容易导致训练的结果难以收敛。通过 Experience Replay 机制对存储下来的样本进行随机采样,在一定程度上能够去除这种相关性,进而更容易收敛。当然,这种方法也有弊端,就是训练的时候是 offline 的形式,无法做到 online 的形式。
除此之外,上面算法流程图中的 aciton-value function 就是一个深度神经网络,因为神经网络是被证明有万有逼近的能力的,也就是能够拟合任意一个函数;一个 episode 相当于 一个 epoch;同时也采用了 \(\epsilon\)-greedy 策略。代码实现可参考上面 FlappyBird 的 DQN 实现。
上面提到的 DQN 是最原始的的网络,后面 Deepmind 对其进行了多种改进,比如说 Nature DQN 增加了一种新机制 separate Target Network,就是计算上图的 \(y_j\) 的时候不采用网络 \(Q\), 而是采用另外一个网络 (也就是 Target Network) \(Q'\), 原因是上面计算 \(y_j\) 和 Q 估计都采用相同的网络 \(Q\),这样使得 Q 大的样本,y 也会大,这样模型震荡和发散可能性变大,其原因其实还是两者的关联性较大。而采用另外一个独立的网络使得训练震荡发散可能性降低,更加稳定。一般 \(Q'\) 会直接采用旧的 \(Q\), 比如说 10 个 epoch 前的 \(Q\).
除此之外,大幅度提升 DQN 玩 Atari 性能的主要就是 Double DQN,Prioritised Replay 还有 Dueling Network 三大方法;这里不详细展开,有兴趣可参考这两篇文章:DQN 从入门到放弃 6 DQN 的各种改进 和 深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction。
综上,本文介绍了强化学习中基于 value 的方法:包括 Q-Learning 以及跟 Q-Learning 非常相似的 Sarsa,同时介绍了通过 DQN 解决状态无限导致 Q-Table 过大的问题。需要注意的是 DQN 只能解决动作有限的问题,对于动作无限或者说动作取值为连续值的情况,需要依赖于 policy gradient 这一类算法,而这一类算法也是目前更为推崇的算法,在下一章将介绍 Policy Gradient 以及结合 Policy Gradient 和 Q-Learning 的 Actor-Critic 方法。
参考
- Q Learning
- Sarsa
- When to choose SARSA vs. Q Learning
- DQN 从入门到放弃 5 深度解读 DQN 算法
- 深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction