用于决策的世界模型 — 论文 World Models (2018) & PlaNet (2019) 讲解

参考资料:

世界模型

简介

世界模型:一种理解世界当前状态预测其未来动态的工具。

世界模型的两个主要功能

  1. 构建内部表征以理解世界运作机制。
  2. 预测未来状态以模拟和指导决策。

分类

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

图片来自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.

  • 作者按照模型的侧重点不同,将世界模型分成两个大类,即:
    • Internal Representations.
    • Future Predictions.
  • 经常能在网上刷到的LeCun力推世界模型,说的是JEPA.
  • 左边分支的世界模型也可以做"future prediction",作为学习模型参数过程的一个副产物吧 (视觉模块的reconstruction)。

这里讨论的是两篇world model for decision-making的文章。

World Models (2018)

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

AI社区中,首篇系统性介绍世界模型的文章。

人类的心理模型

简单可以概括成以下几点:

  • 对于外部世界的大量信息流,人脑能够学习到外部世界时空信息的抽象表示,作为我们对外部世界的"建模"。
  • 我们所看到一切都基于脑中模型对未来的预测。

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

  • 我们能够基于这个预测模型本能地行动,在面对危险时做出快速的反射性行为。

打棒球的例子: 击球手需要在毫秒级别的时间内决定如何挥棒 —— 这比视觉信号到大脑的时间还要短。

在之后的世界模型结构和实验中,都可以看到这个心理模型的影子。

模型结构

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

世界模型主要由两个模块组成:视觉模块、记忆模块。
1. 视觉模块:将外部世界的高维观测,压缩成低维的特征。
2. 记忆模块:整合历史信息,预测未来。

控制器会利用世界模型给出的信息进行决策。

视觉模块

作者在文章中使用VAE的Encoder部分作为视觉模块。
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

记忆模块

作者在文章中使用MDN-RNN作为记忆模块。
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

  • MDN指的是mixture density networks,就是一个建模混合模型的网络,文中使用的是高斯混合模型 (GMM),此时神经网络除了输出每个高斯分布的均值和标准差,还需要输出用于选择高斯分布的类别分布。
  • MDN会接受一个temperature参数(tau),用于调整不确定性。
  • 在图中,MDN-RNN建模的是(P(z_{t+1}mid a_t, z_t, h_t)).
  • 除了隐状态之外,记忆模块可能还需要建模其他东西,比如奖励(P(r_{t+1} mid a_t, z_t, h_t)),游戏结束的信号(P(text{done}_{t+1} mid a_t, z_t, h_t)).

NOTE:为什么要使用混合模型,即使VAE的隐变量空间只是一个对角高斯?作者的解释是:混合模型中的离散部分 (选择哪一个高斯组分),有利于建模环境中的离散随机事件。比如说NPC在平静状态和警觉状态下的表现不同。

控制器

作者将整个模型的复杂性都集中到了视觉和记忆模块,有意使得控制器的结构尽可能简单:

[a_t = W_c[z_t~~h_t] + b_c ]

就是单层的神经网络。

模型训练和实验

文章官网World Models,有gif演示,而且可以试玩模型"梦中"的游戏。

训练

两个实验都是先单独训练世界模型 (无监督):

  1. 使用随机策略收集一系列的游戏图像。
  2. 使用这些图像训练好VAE。
  3. 在训练好的VAE基础上,训练好MDN-RNN。

之后部署世界模型并训练控制器。两个实验的主要区别在部署:

  • Car Racing实验:直接在实际环境部署,训练好了控制器之后,又给出了在模型"梦中"的模拟。用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

  • VizDoom实验:先在"梦中"部署,训练好了控制器之后,再将整个模型转移到实际环境查看效果。用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

NOTE:在两个实验中,世界模型都没有建模环境的奖励。第一个实验中,奖励只在训练控制器的时候由实际环境给出;第二个实验中,指标是存活时间,不需要奖励。

REMARK:训练成功之后,模型实际上成为了游戏的"模拟器",学习到了游戏逻辑 (角色中弹后会重新开始)、敌人行为 (按一定时间间隔发射子弹)、物理机制 (子弹飞行速度)等。

实验

Car Racing:
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

VizDoom:
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

消融实验1 -- 视觉模块+记忆模块的优越性
在Car Racing中,消融实验显示,单独的视觉模块效果不如一整个的世界模型 (但是也已经超过了DQN和A3C)

消融实验2 -- 用tau调整随机性
在VizDoom实验中,由于模型并非完全精确,控制器可能会利用模型的缺陷来在模拟器中达到高分,一旦部署到实际环境,控制器就不行了。

为了防止这一点,MDN-RNN预测的是具有随机性的环境,并通过调整不确定性参数(tau)来控制随机性。在实验中,(tau=1.15)时效果最好。

(tau=0.1)时,模型几乎是确定性的,这时候敌人甚至无法发射子弹,所以出现了在模拟器中非常高分,实际环境中却非常低分的情况。

跑分对比实验

  • Car Racing实验:取得的分数超过了先前的基于深度强化学习的方法,如DQN、A3C.
  • VizDoom实验:在梦中学会了如何躲避怪物的子弹,部署到实际环境后的存活时长也超过了先前。

迭代训练过程

本文的实验环境简单,所以是使用随机策略采样,分别训练三个模块。面对更复杂的任务,可能需要三个模块一起训练,但是本文只是提了一下记忆模块和控制器一起训练的流程:
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

三个模块一起训练的好处是:

  1. 视觉模块会倾向于学习到有利于当前任务的特征。
  2. 记忆模块可以对控制器进行学习,控制器又可以基于记忆模块继续改进,如此往复。
  3. 可以使用训练中的控制器进行轨迹采样而不是随机策略。

Learning Latent Dynamics for Planning from Pixels (2019)

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

相对于上一篇,这篇的改进:

  1. 假定了环境是部分可观测马尔可夫决策过程 (POMDP),世界模型就是在学习这个POMDP.
  2. 给出了一套结合模型预测控制 (MPC) 方法的训练过程 —— Deep Planning Network (PlaNet).
  3. 提出基于确定性和随机性结合的状态空间模型 (RSSM),而不是仅有确定性状态的RNN和仅有随机性状态的SSM.
  4. 给出了适用于多步预测的变分推断方法 —— latent overshooting.

Problem setup

假定实际的环境是一个POMDP:
用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

目标是学习到一个策略,能够最大化期望累积回报(mathbb E[sum r_t])

Deep planning network

这里先讲世界模型+MPC的学习和规划算法。

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

while循环内部,总体上分成三个部分:模型学习,实时规划+数据收集,更新数据库。

模型学习

从数据库中随机抽取观测序列的小批量,然后使用梯度方法学习。

实时规划+数据收集

总体上就是一个有限时间域的MPC框架,在每个time step按三步走:

  1. Observe:获得当前时刻的状态。由于这里在隐状态空间进行规划,所以需要从历史的观测数据中推断当前状态 (通过隐变量的后验概率)。
  2. Predict and plan:利用当前学习到的模型,解一个有限时间域的最优控制问题,获得一串动作序列。本文中的planner使用的是cross entropy method (CEM).
  3. Act:对环境使用这串动作序列的第一个动作(a_t),移动到下一个time step. 这里用了一个trick,把取得的动作(a_t)重复了(R)次 (用相同的action,连续走了(R)步),取reward的总和作为当前时刻的reward,取最终的第(R)观测(o_{t+1}^R)作为下一个时刻的观测(o_{t+1})

更新数据库

将上一个部分收集到的观测序列加入到数据库中,以供世界模型的进一步更新。

NOTE:相对于model-free RL算法,model-based planning的一大优势就是数据利用率提高了。体现在planning取得的观测序列可以反复用于世界模型的学习。

RSSM

这种模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相关的文章,好像要从头到尾讲明白 (像VAE那样) 比较复杂。

这里浅浅讲一下世界模型的结构以及训练的Loss。

Latent state-space model

用于决策的世界模型 -- 论文 World Models (2018) & PlaNet (2019) 讲解

使用下面的encoder来近似后验概率:

[q(s_{le t} mid o_{le t},a_{<t}) = prod_{t=1}^T q(s_tmid s_{t-1},a_{t-1},o_t) ]

都使用神经网络参数化的高斯分布表示,其中observation model和encoder用的是卷积网络。

Training Objective

通过最大化log Evidence来训练:

[argmax ln p(o_{le t} mid a_{<t}) ]

接下来推导ELBO.

先拆成边际化的形式

[ln p(o_{le T} mid a_{<T}) = ln int p(o_{le T}, s_{le T} mid a_{<T}) text{d}s ]

把联合概率拆开

[ln p(o_{le T} mid a_{<T}) = ln int p(o_{le T} mid s_{le T}, a_{<T}) p(s_{le T} mid a_{<T}) text{d}s ]

写成期望的形式

[ln p(o_{le t} mid a_{<t}) = ln mathbb E_{p(s_{le t}mid a_{<t})}[ p(o_{le t} mid s_{le t}, a_{<t}) ] ]

利用重要性采样方法,转变成从encoder采样

[ln p(o_{le t} mid a_{<t}) = ln mathbb E_{q(s_{le t}mid o_{le t},a_{<t})}[ p(o_{le t} mid s_{le t}, a_{<t}) p(s_{le t}mid a_{<t}) / q(s_{le t}mid o_{le t},a_{<t})] ]

链式分解,并利用模型的条件独立性化简 (概率图参考下面的)

[ln p(o_{le t} mid a_{<t}) = ln mathbb E_{q(s_{le t}mid o_{le t},a_{<t})}[prod p(o_t mid s_t) p(s_t mid s_{t-1},a_{t-1}) / q(s_{le t}mid o_{le t},a_{<t})] ]

根据Jensen不等式,(ln mathbb E[x] ge mathbb E[ln(x)])

[ln p(o_{le t} mid a_{<t}) ge mathbb E_{q(s_{le t}mid o_{le t},a_{<t})}[sum_t ln p(o_t mid s_t) + ln p(s_t mid s_{t-1},a_{t-1}) - ln q(s_{le t}mid o_{le t},a_{<t})] ]

右边可以写成reconstruction + KL的形式,最后就是
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

确定性和随机性结合 - RSSM

用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

  • 纯确定性的世界模型:模型难以预测多种可能的未来情况;容易被planner利用模型缺陷 (在World Models中,通过MDN添加随机性来缓解这一点,但本质还是确定性的)
  • 纯随机性的世界模型:模型难以记住信息,导致产生前后不一致的预测结果。

所以作者考虑将确定性和随机性结合,称这种结构为RSSM.

相对于上一篇,把记忆模块换成了RSSM。

Latent Overshooting

之前讨论的都是(s_t to s_{t+1})的单步预测,如果每次单步预测都准确无误,那多步预测肯定也没问题。但是由于模型本身有局限,所以不一定能很好的推广到多部预测。

于是作者考虑了直接进行跨步的预测,先通过对中间几步隐变量边际化得到了跨步预测的转移
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

并且推导了针对跨步预测的变分bound
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

把考虑不同的步幅(d),求和,就得到latent overshooting的目标函数
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

实验结果

用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

DeepMind control suite环境:图像作为观测,连续动作空间。

消融实验

  • 验证PlaNet的数据收集过程有优势。Random Collection指的是用随机策略收集数据而不是通过MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接从1000条随机采的动作序列里选最好的那条。最后PlaNet在大部分情况都明显好于另外两种。
    用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

  • RSSM和SSM、GRU的对比。观察到RSSM明显好于后两者,表明了确定性+随机性结合的优势。
    用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

  • 是否加入latent overshooting作为变分目标。观察到Latent overshooting使RSSM的表现轻微变差,但是在一些任务上让DRNN的表现变好了。
    用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

跑分对比实验
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

  • PlaNet的分数能打败A3C。
  • PlaNet的分数总体不如D4PG,但是大部分任务相差不多。
  • PlaNet在所有任务上,数据利用率都好于D4PG.
  • PlaNet (CEM + 世界模型) 和 CEM + true simulator对比只差了一些,体现出世界模型较好地学习到了环境。

六个任务一起训练
每次循环中,agent面对的可能是不同的环境,所以数据库中抽取出来的轨迹也是打乱的。
用于决策的世界模型 -- 论文 World Models (2018) &amp; PlaNet (2019) 讲解

最后跑分不如单独训练,但是体现出了agent能够自己判断出面对的是哪个任务了。

代码选讲

代码来自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics

主要是看看transition model和模型训练过程。解释都在注释里,有部分注释是代码库原有的。

Transition model

class TransitionModel(jit.ScriptModule):   __constants__ = ['min_std_dev']    def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):     super().__init__()     self.act_fn = getattr(F, activation_function)     self.min_std_dev = min_std_dev     self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)     self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1     self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t     self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std     self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t     self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std    # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations   # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):   # t :  0  1  2  3  4  5   # o :    -X--X--X--X--X-  设置了初始的隐状态是None,所以不考虑0时刻的obs   # a : -X--X--X--X--X-     不考虑最后一个action,因为最后一个action没有后续的obs   # n : -X--X--X--X--X-   # pb: -X-   # ps: -X-   # b : -x--X--X--X--X--X-   # s : -x--X--X--X--X--X-    # 输入的shape都是(time_step, batch_size, *)   @jit.script_method   def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:     # 后面都是动态更新,为了保留grad,不能使用单个tensor作为buffer,所以创建了几个list     T = actions.size(0) + 1 # 实际需要的list长度,参考上面的图     beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =        [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T     beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0时刻赋初值      # 每次循环开始,是已知t时刻的信息,进一步计算t+1时刻的信息     for t in range(T - 1):       # 根据情况合适的s,因为模型可以在脱离observations的情况下自己预测       # 如果observations为None,则使用先验状态 (模型一步步生成出来的),否则使用后验状态 (根据历史的obs和action推断出来的)       _state = prior_states[t] if observations is None else posterior_states[t]        # terminal则说明这段序列已经结束了,所以把状态mask掉 (就是0)       _state = _state if nonterminals is None else _state * nonterminals[t]          # 注意下面每一块的hidden是临时变量,表示的是不同的意思        # 计算确定性隐状态h = f(s_t, a_t, h_t)       hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起       beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 对应概率图中从s,a,h到h的实线        # 计算隐状态s的先验 p(s_t|s_t-1,a_t-1)       hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 对应概率图中从h到s的实线       prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)       prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus来保证std_devs为正,并且使用min_std_dev来保证std_devs不会太小       prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])             # 计算隐状态s的后验 q(s_t|o≤t,a<t)       if observations is not None: # 只有observations不为None时,才计算后验         t_ = t - 1  # 这是实现的问题,因为传进来的是obs[1:],所以应该用t_+1才能索引到对应的obs         hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 对应概率图中的两条虚线         posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)         posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev         posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])      # 返回h,s,以及先验和后验的均值和方差     hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]     if observations is not None:       hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]     return hidden 

世界模型训练
只截取了一小部分,重点看loss func是如何计算的。

  # Model fitting   losses = []   for s in tqdm(range(args.collect_interval)):     # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)     observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0      # Create initial belief and state for time t = 0     init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)      # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)     # 一次把整个隐状态序列全部计算出来     beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =       transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])      # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)     # Reconstruction loss都使用MSE     # mean(dim=(0, 1))对batch和time进行平均     observation_loss =       F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))     reward_loss =       F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))     # KL loss, 计算了后验q(s_t|o≤t,a<t)和先验p(s_t|s_t-1,a_t-1)的KL散度     kl_loss =       torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out  """ 后面的部分略 """ 

发表评论

相关文章