Exploring Transformer Models for Reinforcement Learning

MLPs are widely used in RL to implement learnable agents that act in a given environment and are trained with a specific algorithm. Recent work in NLP has shown that the Transformer can replace and outperform previously dominant architectures such as RNNs and LSTMs in most tasks, which has led to its adoption in areas outside NLP such as computer vision. In RL, however, the Transformer architecture is still not widely adopted: agents are still typically implemented as simple MLPs, or as RNNs when they need memory.

Despite its outstanding performance, the Transformer is known to be hard to train and to require large amounts of data and computation to converge. Its training typically relies on tricks that stabilize it or improve its performance, such as elaborate learning-rate schedules (e.g., linear warmup or cosine decay), gradient clipping, dropout, or specialized weight initialization schemes.
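To make two of these tricks concrete, here is a minimal Python sketch of a linear-warmup-plus-cosine-decay learning-rate schedule and of gradient clipping by global norm. The function names and default values (base_lr, warmup_steps, total_steps, max_norm) are illustrative assumptions, not settings taken from any of the works discussed here; deep learning frameworks ship their own versions of both.

```python
import math


def warmup_cosine_lr(step, base_lr=3e-4, warmup_steps=1_000,
                     total_steps=100_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay down to min_lr (illustrative defaults)."""
    if step < warmup_steps:
        # ramp linearly from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # half a cosine period: 1 at progress=0, 0 at progress=1
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradients so that their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


for step in (0, 500, 999, 1_000, 50_000, 100_000):
    print(step, f"{warmup_cosine_lr(step):.2e}")
print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))
```

Clipping by the global norm (rather than per element) preserves the direction of the update and only shrinks its size, which is why it is the usual choice for Transformers.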

Beyond the difficulty of training a Transformer even in the much simpler supervised setting, RL poses a number of challenges that supervised learning does not, including long-term temporal dependencies, sparse rewards, and the balance between exploration and exploitation.

Figure: High-level architecture of the Transformer

In 2018, Mishra et al. ran some initial experiments on the feasibility of using the Transformer in RL, on elementary environments such as simple bandit tasks and tabular Markov Decision Processes (MDPs).

However, they concluded that the Transformer failed to learn in both environments, often performing on par with a random policy, which led the authors to hypothesize that the Transformer architecture was not suitable for processing sequential information in the RL context.

Some of the reasons that could hamper the training of Transformer-based policies include:

  • Sample efficiency: in actor-critic algorithms, the actor and the critic are trained alternately, so each model must adapt quickly to the other's changes from only a few samples. The Transformer, however, is known to need a large number of samples to learn, i.e., it lacks sample efficiency.
  • Implicit sequential information: RL involves learning a policy that solves an MDP, so sequential information plays a fundamental role in predicting the best next action. This is usually not a problem in typical RL scenarios, since agents receive environment observations one at a time, in their natural temporal order. A Transformer, in contrast, receives a whole sequence of observations at each step, and their order is conveyed only through a positional encoding. Hence, sequential information is not given explicitly but has to be learned by the Transformer.
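The sinusoidal positional encoding of the original Transformer is one common way to inject order into a window of observations. Without such a term, self-attention treats its input as an unordered set, which is why the order has to be learned. The sketch below is a plain-Python illustration; the function names and the toy observation window are hypothetical, and many RL agents use learned or relative encodings instead.

```python
import math


def sinusoidal_positional_encoding(seq_len, d_model):
    """PE[pos][2k] = sin(pos / 10000^(2k/d)), PE[pos][2k+1] = cos(same angle)."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):  # i plays the role of 2k
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


def add_positions(observations):
    """Add the encoding to a window of observation embeddings (oldest first)."""
    pe = sinusoidal_positional_encoding(len(observations), len(observations[0]))
    return [[o + p for o, p in zip(obs, enc)] for obs, enc in zip(observations, pe)]


# Two identical observations at different time steps become distinguishable.
window = [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]
encoded = add_positions(window)
print(encoded[0] != encoded[1])  # → True
```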

One of the first successful cases: GTrXL

It was not until 2019, with the work of Parisotto et al., that the Transformer achieved a remarkable breakthrough in RL, through a new architecture called GTrXL (Gated Transformer-XL).

Before diving into the details, let’s briefly introduce Transformer-XL and its advantage over the vanilla Transformer. Transformer-XL caches the hidden states computed for previous segments and reuses them as extended context, which enables learning dependencies beyond a fixed length without disrupting temporal coherence. This means that Transformer-XL can exploit both the current input segment and past segments to make predictions.

Transformer-XL also introduces a novel relative positional encoding scheme, which keeps temporal information coherent when cached states are reused. Together, these mechanisms not only capture longer-term dependencies but also resolve the context fragmentation problem.
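The segment-level memory can be sketched as follows. This is a deliberately simplified, single-layer, single-head illustration under several assumptions: keys and values are the hidden states themselves (no learned projections), relative positional encodings are omitted, and the class and method names are hypothetical. It does show the core mechanism: queries come from the current segment, while keys and values span a cache of states from earlier segments plus the current one.

```python
import math


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class RecurrentSegmentAttention:
    """One attention layer with Transformer-XL-style segment memory (simplified)."""

    def __init__(self, mem_len):
        self.mem_len = mem_len
        self.memory = []  # cached hidden states from previous segments

    def forward(self, segment):
        context = self.memory + segment
        offset = len(self.memory)
        outputs = []
        for t, query in enumerate(segment):
            # causal: the whole memory plus current-segment steps up to and including t
            visible = context[: offset + t + 1]
            outputs.append(attend(query, visible, visible))
        # keep only the newest mem_len states; a real model would also stop gradients here
        self.memory = context[-self.mem_len:] if self.mem_len > 0 else []
        return outputs


layer = RecurrentSegmentAttention(mem_len=3)
first = layer.forward([[1.0, 0.0], [0.0, 1.0]])
second = layer.forward([[1.0, 1.0], [0.5, 0.5]])  # these queries also see the first segment
print(len(layer.memory), second[0])
```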

To make the Transformer achieve this level of performance, the authors applied three main architectural changes:

  • Identity Map Reordering (IMR): this variation places layer normalization only on the input stream of each sub-module (Multi-Head Attention and Position-Wise MLP), with residual connections around each sub-module. A key benefit of this reordering is that it enables an identity map from the input of the Transformer at the first layer to its output after the last layer. If, in the early stages of training, the sub-modules output values that are near zero in expectation, the state encoding passes untransformed through the residual connections to the policy and value heads. The agent can then learn a Markovian policy at the start of training, ignoring the contribution of past observations coming from the Transformer-XL memory. In many environments, reactive behaviors need to be learned before memory-based ones can be used effectively: an agent needs to learn how to walk before it can learn to remember where it has walked.
  • Gating Transformer (GT): this tweak replaces the residual connections with gating layers (the best-performing variant uses GRU-style gates). This lets the model learn to control how much information flows from the residual connection versus its corresponding sub-module.
  • Gated Identity Initialization (GII): IMR improves training stability because it initializes the agent close to a Markovian policy at the beginning of training. GII reinforces this effect by initializing the gating bias so that each gate starts close to the identity map.
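The three changes can be sketched for a single sub-layer. The gate below follows the GRU-type gating described for GTrXL: a reset gate r, an update gate z shifted by a bias b_g, and a candidate h, combined as (1 - z) * x + z * h. With b_g > 0, z starts near 0 and the sub-layer starts close to the identity map, which is the idea behind GII. The `submodule` argument is a hypothetical stand-in for multi-head attention or the position-wise MLP; the weight scale, seed, and bias values are illustrative assumptions.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(matrix, vec):
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def layer_norm(vec, eps=1e-5):
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


class GRUGate:
    """GRU-style gate g(x, y); gate_bias (b_g) pushes the update gate z toward 0."""

    def __init__(self, dim, gate_bias=2.0, init_scale=0.1, seed=0):
        rng = random.Random(seed)

        def matrix():
            return [[rng.gauss(0.0, init_scale) for _ in range(dim)] for _ in range(dim)]

        self.w_r, self.u_r, self.w_z, self.u_z, self.w_h, self.u_h = (matrix() for _ in range(6))
        self.gate_bias = gate_bias

    def __call__(self, x, y):
        r = [sigmoid(a + b) for a, b in zip(matvec(self.w_r, y), matvec(self.u_r, x))]
        z = [sigmoid(a + b - self.gate_bias)
             for a, b in zip(matvec(self.w_z, y), matvec(self.u_z, x))]
        rx = [ri * xi for ri, xi in zip(r, x)]
        h = [math.tanh(a + b) for a, b in zip(matvec(self.w_h, y), matvec(self.u_h, rx))]
        # z close to 0 at initialization => output close to x (identity map)
        return [(1.0 - zi) * xi + zi * hi for zi, xi, hi in zip(z, x, h)]


def gated_sublayer(x, submodule, gate):
    """IMR: layer norm only on the submodule input; GT: a gate replaces x + y."""
    y = [max(0.0, v) for v in submodule(layer_norm(x))]  # ReLU on the submodule output
    return gate(x, y)


x = [1.0, -2.0, 0.5, 3.0]
identity_like = gated_sublayer(x, submodule=lambda v: v, gate=GRUGate(4, gate_bias=10.0))
print([round(v, 3) for v in identity_like])
```

With a large gate bias the output stays very close to x, so the block initially behaves like an identity map; a smaller bias lets the submodule's contribution through from the start.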

With the above architecture modifications, GTrXL shows a substantial and consistent improvement in memory-based environments such as DMLab-30 over the ubiquitous LSTM architectures and suffers no degradation of performance in reactive environments.

GTrXL can serve as the backbone of many RL algorithms: V-MPO (used in the original paper), R2D2 (in place of its recurrent network), or simply DQN. Another benefit of the Transformer architecture is its ability to scale to very large and deep models and to make effective use of this extra capacity on larger datasets, which makes it attractive for offline RL and for large, complex environments.

Expanding GTrXL

GTrXL thus set a baseline for Transformers in RL, and several more recent works aim to improve on the vanilla GTrXL. Some worth mentioning include:

Adaptive Transformer in RL: published in 2020, just a year after the work of Parisotto et al., this work by Kumar et al. significantly increased the memory size of GTrXL and improved its attention mechanism by adding an Adaptive Attention Span.

This mechanism aims to learn an optimal attention span for each layer and each head, thus allowing the attention to selectively attend to past timesteps using much larger memory blocks. Experiments reported that lower layers learn a shorter attention span to capture local information and upper layers have large spans to capture global information. Having a shorter attention span allows the model to save memory and attend to a large memory only if needed.

This in turn reduces the cost of computing attention. In practice, the computational saving comes from clipping the memory so that only memory cells that are not masked out are retained. GTrXL with Adaptive Attention Span retains performance similar to its fixed-memory-length version in both memory-based and reactive environments, while training faster and consuming fewer resources.
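The soft masking function behind Adaptive Attention Span (proposed by Sukhbaatar et al. for language modeling) is m_z(x) = min(max((R + z - x) / R, 0), 1), where x is the distance to a past timestep, z is the learned span of a head, and R is a fixed ramp width. The sketch below applies it to one query's attention scores and drops keys whose mask is zero, which corresponds to the memory clipping described above. Here the span is passed in as a number rather than learned, and the function names are hypothetical.

```python
import math


def span_mask(distance, span, ramp):
    """Soft mask m_z(x) = clamp((R + z - x) / R, 0, 1)."""
    return min(max((ramp + span - distance) / ramp, 0.0), 1.0)


def adaptive_span_attention_weights(scores, span, ramp):
    """scores[d] is the raw attention score for the key d steps in the past."""
    # Memory clipping: keys whose mask is exactly 0 are dropped before any exp().
    kept = [d for d in range(len(scores)) if span_mask(d, span, ramp) > 0.0]
    top = max(scores[d] for d in kept)
    raw = {d: span_mask(d, span, ramp) * math.exp(scores[d] - top) for d in kept}
    total = sum(raw.values())
    return [raw.get(d, 0.0) / total for d in range(len(scores))]


# With span z = 2 and ramp R = 4, only distances 0..5 (the last z + R = 6 steps) survive.
weights = adaptive_span_attention_weights([0.0] * 10, span=2.0, ramp=4.0)
print([round(w, 3) for w in weights])
```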

Contrastive BERT (CoBERL): inspired by BERT, a bidirectional Transformer encoder, CoBERL is a reinforcement learning agent that combines a contrastive loss and a hybrid LSTM-Transformer gated architecture to tackle the challenge of improving data efficiency for RL, a problem that is particularly relevant when working on high dimensional inputs.

It uses bidirectional masked prediction from BERT in combination with a generalization of recent contrastive methods to learn better and more consistent representations for transformers in RL, without the need for hand-engineered data augmentations or domain knowledge. Moreover, combining GTrXL with LSTMs using a trainable gate allows the agent to learn to exploit the representations offered by the Transformer only when an environment requires it, and avoid the extra complexity when not needed.
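As a generic illustration of a contrastive objective, and not CoBERL's exact loss (the paper generalizes earlier contrastive methods and differs in its details), the sketch below computes an InfoNCE-style loss for one masked timestep: the prediction should be closer, in cosine similarity, to the embedding of its own timestep than to the embeddings of other timesteps. The vectors, the temperature, and the function names are hypothetical.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(prediction, target, negatives, temperature=0.1):
    """-log p(target | prediction) under a softmax over {target} plus the negatives."""
    logits = [cosine(prediction, target) / temperature]
    logits += [cosine(prediction, n) / temperature for n in negatives]
    top = max(logits)
    log_partition = top + math.log(sum(math.exp(logit - top) for logit in logits))
    return log_partition - logits[0]


# Hypothetical vectors: output at a masked timestep vs. embeddings of the sequence.
pred = [0.9, 0.1, 0.0]
same_step = [1.0, 0.0, 0.0]
other_steps = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(info_nce(pred, same_step, other_steps))                       # small loss
print(info_nce(pred, other_steps[0], [same_step, other_steps[1]]))  # larger loss
```

Because the negatives come from other timesteps of the same sequence, this kind of loss needs no hand-engineered data augmentation, which matches the motivation given above.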

CoBERL consistently improves performance over its predecessor across the full Atari suite, a set of control tasks, and challenging 3D environments.

Transformer in offline-RL

After the strong results achieved with online RL algorithms, Transformers also started to challenge existing offline RL methods. In offline RL, the agent cannot interact with the environment; it only has a fixed, limited dataset of (often suboptimal) trajectories from which to learn its task.

In this setting, training is traditionally challenging due to error propagation and value overestimation. Some of the most relevant Transformer-based algorithms are:

Decision Transformer: presented by Chen et al., this framework casts RL as a conditional sequence modeling problem. Instead of fitting value functions or computing policy gradients, Decision Transformer simply outputs actions by leveraging a causally masked, GPT-style Transformer.

Trajectories of returns-to-go, states, and actions are first embedded by modality-specific embedding layers, to which a learned episodic timestep encoding is added. The resulting tokens are fed into a Transformer that predicts actions auto-regressively under a causal self-attention mask. At evaluation time, the first return-to-go is set to the desired cumulative return and then decreased by each reward received, so that the model generates future actions intended to achieve that return.
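The input pipeline can be sketched in a few lines: compute undiscounted returns-to-go R_t = r_t + r_{t+1} + ... + r_T, interleave them with states and actions as (R_1, s_1, a_1, R_2, ...), and restrict attention with a causal mask. In the paper, the action a_t is predicted from the output at the s_t token. This is a plain-Python sketch with hypothetical helper names; the embedding layers and the Transformer itself are left out.

```python
def returns_to_go(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T (undiscounted)."""
    rtg, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg


def interleave_tokens(rtg, states, actions):
    """Order tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...), tagged with the timestep."""
    tokens = []
    for t, (ret, state, action) in enumerate(zip(rtg, states, actions)):
        tokens += [("rtg", t, ret), ("state", t, state), ("action", t, action)]
    return tokens


def causal_mask(n):
    """mask[i][j] is True when token i may attend to token j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


rewards = [1.0, 0.0, 2.0]
print(returns_to_go(rewards))  # [3.0, 2.0, 2.0]
tokens = interleave_tokens(returns_to_go(rewards), ["s0", "s1", "s2"], ["a0", "a1", "a2"])

# At evaluation time: start from a desired return and subtract each observed reward.
target = 10.0
for r in rewards:
    target -= r  # next return-to-go token fed to the model
print(target)  # 7.0
```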

Despite its simplicity, Decision Transformer matches or exceeds the performance of state-of-the-art model-free offline RL baselines on Atari, OpenAI Gym, and Key-to-Door tasks, and performs particularly well when long-term credit assignment is required.

Trajectory Transformer: developed concurrently with Decision Transformer, Trajectory Transformer also views RL as a sequence modeling problem, but it differs in its training, which focuses on modeling distributions over entire trajectories. Its input consists of trajectories of states, actions, and rewards, treated as unstructured sequences of tokens. States and actions are discretized along each dimension independently, which makes the modeled distribution over trajectories more expressive.
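Per-dimension discretization can be sketched with uniform bins (quantile binning is another common choice). Each state dimension, action dimension, and the reward become separate tokens, so one transition turns into len(s) + len(a) + 1 tokens. The bounds, the bin count, and the helper names below are illustrative assumptions.

```python
def discretize(vector, lows, highs, n_bins):
    """Map each dimension independently to one of n_bins uniform bins."""
    tokens = []
    for x, lo, hi in zip(vector, lows, highs):
        idx = int((x - lo) / (hi - lo) * n_bins)
        tokens.append(min(max(idx, 0), n_bins - 1))  # clamp out-of-range values
    return tokens


def undiscretize(tokens, lows, highs, n_bins):
    """Return the center of each bin (a lossy inverse of discretize)."""
    return [lo + (t + 0.5) * (hi - lo) / n_bins for t, lo, hi in zip(tokens, lows, highs)]


def tokenize_transition(state, action, reward, bounds, n_bins):
    """Flatten one (s, a, r) step into len(s) + len(a) + 1 tokens."""
    values = list(state) + list(action) + [reward]
    lows, highs = zip(*bounds)
    return discretize(values, lows, highs, n_bins)


# Hypothetical bounds: 2 state dims, 1 action dim, then the reward.
bounds = [(-1.0, 1.0), (-1.0, 1.0), (-2.0, 2.0), (0.0, 1.0)]
print(tokenize_transition([0.0, 0.99], [-2.0], 0.5, bounds, n_bins=100))
```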

Training follows a standard auto-regressive objective with teacher forcing, while beam search is used at decoding time to plan. With only minor modifications to this decoding procedure, Trajectory Transformer can be repurposed for imitation learning, goal-conditioned reinforcement learning, and offline reinforcement learning.
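Beam search itself is generic. The sketch below keeps the beam_width best partial sequences at each step, ranking them by the sum of per-token scores returned by a model. With log-probabilities as scores it searches for likely trajectories; for offline RL, the Trajectory Transformer authors replace the log-probabilities with predicted rewards. `toy_model` is a hypothetical stand-in for a trained Transformer.

```python
import math


def beam_search(next_scores, prefix, steps, beam_width):
    """Keep the beam_width highest-scoring continuations at every step.

    next_scores(seq) returns {token: score}; scores add up along a sequence.
    """
    beams = [(0.0, list(prefix))]
    for _ in range(steps):
        candidates = [
            (total + score, seq + [token])
            for total, seq in beams
            for token, score in next_scores(seq).items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]
    return beams


def toy_model(seq):
    """Hypothetical model over tokens {0, 1, 2} that prefers repeating the last token."""
    last = seq[-1] if seq else 0
    return {tok: math.log(0.7 if tok == last else 0.15) for tok in (0, 1, 2)}


best_score, best_seq = beam_search(toy_model, prefix=[1], steps=3, beam_width=2)[0]
print(best_seq, round(best_score, 3))
```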

Some final thoughts

These works show that Transformers can be used for RL: in reactive environments, Transformer-based agents can match the performance of standard policy-gradient or Q-learning agents, while in environments that require memory they can even surpass RNN-based agents, across different RL settings.

While RNNs tend to put more emphasis on the most recent states, the attention mechanism can learn which parts of a sequence deserve more importance, regardless of their absolute position. In addition, the Transformer can better exploit temporal dependencies among the states and actions in a trajectory and learn better representations for predicting the next action.

Nevertheless, Transformer-based agents inherit some of the architecture's typical, well-known shortcomings: they are currently slower and more resource-intensive than the single-step models often used in model-free control, and can require multiple seconds for action selection when the context window grows large. This precludes real-time control with standard Transformers for most dynamic systems. Moreover, a Transformer offers no strong advantage in purely reactive environments, where it brings no benefit in terms of either performance or speed.

