This post aims to visualize the loss landscapes of some imitation policies (IL policies) trained with GAIL, and of their discriminators, in three common environments: Cartpole, Lunarlander, and Walker2d from MuJoCo. The expert policy for Cartpole and Lunarlander is a simple Double DQN, while the expert for Walker2d, which has a continuous action space, is a DDPG policy. The imitation policies are of the same type as their experts, but their rewards are provided by the GAIL discriminator rather than by the environment, and the expert trajectories are sampled from an expert dataset. The expert dataset was obtained by sampling trajectories with the trained expert policy.
A more comprehensive review of how GAIL works can be found in this post. Here, however, we review some fundamental concepts of the algorithm that we need in order to better understand the experiments in this post.
The original objective optimized to train GAIL is:
$$\min_{\Pi}\max_{D}\; \mathbb{E}_{\Pi}\left[\log D(s,a)\right] + \mathbb{E}_{\Pi_e}\left[\log\left(1 - D(s,a)\right)\right]$$
The input consists of tuples (s, a), where s is an observed state and a is the corresponding action. Trajectories of (s, a) tuples are stored in the expert dataset, where "expert" indicates that those trajectories have been sampled using an expert policy; the goal of the imitation policy is to learn to imitate the expert. EΠ denotes the expectation over (s, a) pairs sampled from trajectories generated by the imitation policy, and EΠe the expectation over (s, a) pairs coming from the expert.
From the formula, we can observe that the discriminator D has to bring D(s, a) close to 1 for pairs sampled from the imitation policy and close to 0 for pairs sampled from the expert. As D approaches this behavior, its objective approaches its maximum value of 0 (both logarithms are never positive). The policy Π has to counter D by minimizing the same objective, that is, it has to bring D(s, a) close to 0 on its own trajectories. This is achieved only if the state-action pairs generated by Π are sufficiently similar to those of the expert. The reward for Π is provided by D and usually takes the form -log(D(s, a)). The reward is high when D(s, a) is close to 0 and ranges from 0 to +∞.
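To make the objective and the reward concrete, here is a minimal NumPy sketch that computes a Monte Carlo estimate of the GAIL objective from discriminator outputs, together with the -log(D(s, a)) reward. The function names `gail_objective` and `gail_reward` are hypothetical and the clipping constant is an assumption to avoid log(0); this is not the DI-engine implementation used for the experiments.

```python
import numpy as np

def gail_objective(d_policy, d_expert, eps=1e-8):
    """Monte Carlo estimate of E_Pi[log D(s,a)] + E_Pie[log(1 - D(s,a))].

    d_policy: discriminator outputs D(s, a) on pairs generated by the imitation policy.
    d_expert: discriminator outputs D(s, a) on pairs from the expert dataset.
    The discriminator maximizes this value (upper bound 0); the policy minimizes it.
    """
    # Clip to avoid log(0) for saturated discriminator outputs (assumed eps)
    d_policy = np.clip(np.asarray(d_policy, dtype=float), eps, 1.0 - eps)
    d_expert = np.clip(np.asarray(d_expert, dtype=float), eps, 1.0 - eps)
    return float(np.mean(np.log(d_policy)) + np.mean(np.log(1.0 - d_expert)))

def gail_reward(d_policy, eps=1e-8):
    """Surrogate reward -log D(s, a): grows without bound as D(s, a) -> 0 (expert-like pairs)."""
    d_policy = np.clip(np.asarray(d_policy, dtype=float), eps, 1.0 - eps)
    return -np.log(d_policy)

# A discriminator that separates the two sources well scores close to the maximum of 0,
# while one that outputs 0.5 everywhere scores 2 * log(0.5), about -1.386.
print(gail_objective([0.99, 0.98], [0.01, 0.02]))
print(gail_objective([0.5, 0.5], [0.5, 0.5]))
print(gail_reward([0.01, 0.5, 0.99]))  # the most expert-like pair gets the largest reward
```

The reward is positive for every output in (0, 1), which matches the range from 0 to +∞ stated above.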
The following table reports the performance of the evaluated policies:
| Environment | Expert mean reward | GAIL mean reward |
|---|---|---|
For the Cartpole and Lunarlander environments, the convergence rule was defined as reaching a mean reward of 200 over 100 consecutive episodes. Since this condition is enough to consider the environment solved, training was terminated as soon as this value was reached.
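The convergence rule above can be expressed as a rolling mean over the most recent episodes. The sketch below is a hypothetical helper (`ConvergenceMonitor` is not part of DI-engine) that assumes "reaching 200" means a mean reward greater than or equal to the threshold over a full window of 100 episodes.

```python
from collections import deque

class ConvergenceMonitor:
    """Hypothetical helper: reports convergence once the mean reward over the
    last `window` episodes reaches `threshold` (200 over 100 episodes in this post)."""

    def __init__(self, threshold=200.0, window=100):
        self.threshold = threshold
        self.rewards = deque(maxlen=window)  # keeps only the most recent episodes

    def update(self, episode_reward):
        """Record one episode's total reward; return True if training can stop."""
        self.rewards.append(episode_reward)
        if len(self.rewards) < self.rewards.maxlen:
            return False  # not enough consecutive episodes yet
        return sum(self.rewards) / len(self.rewards) >= self.threshold

monitor = ConvergenceMonitor()
for episode in range(1000):
    episode_reward = 200.0  # placeholder: use the total reward of a real episode here
    if monitor.update(episode_reward):
        print(f"converged after {episode + 1} episodes")  # prints: converged after 100 episodes
        break
```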
We also show the training reward curves of the IL policy (red) and of the expert policy (orange) for the Walker2d environment:
In reinforcement learning, neural network models are trained on a corpus of observations s and their accompanying actions a by minimizing a loss L(w), which measures how well the network with parameters w predicts the correct action. The form of L(w) depends on the algorithm used to train the policy (for example, the temporal-difference loss derived from the Bellman equation in DQN, or the minimax objective in GAIL). Neural networks used to represent a policy contain many parameters, so their loss functions live in a very high-dimensional space. That's why we need an appropriate method to visualize their loss as a 3D surface plot.
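As an example of an algorithm-specific L(w), the sketch below computes a Double DQN loss for one batch of transitions, the kind of loss minimized by the Cartpole and Lunarlander experts. It is a minimal NumPy illustration, not the DI-engine code: the function name is hypothetical, the Q-values are passed in as arrays instead of being produced by a network, and the squared error is an assumption (many implementations use the Huber loss).

```python
import numpy as np

def double_dqn_loss(q_values, actions, rewards, dones, q_next_online, q_next_target, gamma=0.99):
    """Mean squared TD error with Double DQN targets, i.e. one possible L(w) for a DQN-style policy.

    q_values:      Q(s, . ; w) from the online network, shape (batch, n_actions)
    actions:       integer actions taken in s, shape (batch,)
    rewards/dones: shape (batch,); dones is 1.0 for terminal transitions
    q_next_online: Q(s', . ; w) from the online network
    q_next_target: Q(s', . ; w_target) from the periodically updated target network
    """
    q_values, q_next_online, q_next_target = map(np.asarray, (q_values, q_next_online, q_next_target))
    actions = np.asarray(actions, dtype=int)
    rewards, dones = np.asarray(rewards, dtype=float), np.asarray(dones, dtype=float)
    idx = np.arange(len(actions))
    q_taken = q_values[idx, actions]
    # Double DQN: the online network selects a', the target network evaluates it
    next_actions = np.argmax(q_next_online, axis=1)
    targets = rewards + gamma * (1.0 - dones) * q_next_target[idx, next_actions]
    # Squared error is an assumption here; many implementations use the Huber loss instead
    return float(np.mean((q_taken - targets) ** 2))
```

Decoupling action selection (online network) from action evaluation (target network) is what distinguishes Double DQN from plain DQN, which would take the maximum of the target network's Q-values directly.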
The method we used to visualize the loss landscape is called "Contour Plots & Random Directions with Filter-Wise Normalization"; more details can be found here, and our implementation is based on https://github.com/marcellodebernardi/loss-landscapes. In the following, we briefly summarize how the method works.
The loss landscape is represented in 3D: the horizontal 2D plane corresponds to the parameter space of the model, and the center of the plane corresponds to the original parameters w* obtained after training. Moving along the x and y axes moves the parameters along two directions w1 and w2, up to a maximum of n steps forward or backward. The directions w1 and w2 are randomly generated, have the same size as w*, and are nearly orthogonal, since two random vectors in a high-dimensional space are almost orthogonal with high probability. With filter-wise normalization, each filter of a random direction is rescaled to have the same norm as the corresponding filter of w*, which makes plots comparable across networks whose weights have different scales. Each direction w_m, with m ∈ {1, 2}, is then scaled by the factor (||w*|| / n) / ||w_m||, where ||·|| is the Euclidean norm, so that a single step has norm ||w*|| / n. The z axis (vertical axis) represents the value of the loss at the corresponding parameters w* + x·w1 + y·w2. Hence, the plot shows how the loss of the model varies in a neighborhood of w*. This is useful to visualize the convexity and smoothness of a loss function, as well as to identify local minima. For a more comprehensive introduction, you can refer to the original paper linked above.
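The procedure described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code of the loss-landscapes repository: the function names are hypothetical, the parameters are assumed to be a list of arrays, and each slice along the first axis of a weight array is assumed to be one "filter" (biases are treated as a single filter). A toy least-squares model stands in for a policy or discriminator.

```python
import numpy as np

def filter_normalize(direction, params, eps=1e-10):
    """Rescale every filter of `direction` to the norm of the matching filter in `params`.

    Assumption: for a weight array of shape (out, ...) each slice along axis 0 is one filter;
    1-D arrays such as biases are treated as a single filter.
    """
    normalized = []
    for d, p in zip(direction, params):
        if d.ndim <= 1:
            normalized.append(d * np.linalg.norm(p) / (np.linalg.norm(d) + eps))
        else:
            d_norms = np.linalg.norm(d.reshape(d.shape[0], -1), axis=1)
            p_norms = np.linalg.norm(p.reshape(p.shape[0], -1), axis=1)
            scale = (p_norms / (d_norms + eps)).reshape((-1,) + (1,) * (d.ndim - 1))
            normalized.append(d * scale)
    return normalized

def global_norm(arrays):
    """Euclidean norm of all arrays viewed as one flat vector."""
    return float(np.sqrt(sum(np.sum(a ** 2) for a in arrays)))

def random_plane(params, loss_fn, n=10, seed=0):
    """Evaluate loss_fn(w* + x*w1 + y*w2) for x, y in [-n, n], with w* = params at the centre."""
    rng = np.random.default_rng(seed)
    w_norm = global_norm(params)
    directions = []
    for _ in range(2):
        d = filter_normalize([rng.standard_normal(p.shape) for p in params], params)
        # Scale so that one step has Euclidean norm ||w*|| / n
        factor = (w_norm / n) / global_norm(d)
        directions.append([a * factor for a in d])
    w1, w2 = directions
    offsets = np.arange(-n, n + 1)  # offset 0 corresponds to w*
    surface = np.empty((len(offsets), len(offsets)))
    for i, x in enumerate(offsets):
        for j, y in enumerate(offsets):
            point = [p + x * a + y * b for p, a, b in zip(params, w1, w2)]
            surface[i, j] = loss_fn(point)
    return offsets, surface

# Toy usage: a linear least-squares model whose "trained" parameters are the exact minimizer.
rng = np.random.default_rng(1)
X = rng.standard_normal((64, 3))
w_true = rng.standard_normal((1, 3))
y = X @ w_true.T
params = [w_true.copy(), np.zeros(1)]  # [weight matrix, bias]

def mse(p):
    return float(np.mean((X @ p[0].T + p[1] - y) ** 2))

offsets, surface = random_plane(params, mse, n=10)
print(surface.shape, surface[10, 10])  # the centre is w*, where this toy loss is 0
```

The resulting `surface` array can be passed to any 3D surface or contour plotting routine, with `offsets` on both horizontal axes.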
One important conclusion claimed by the authors is that "landscape geometry has a dramatic effect on generalization." In their words: "First, note that visually flatter minimizers consistently correspond to lower test error, which further strengthens our assertion that filter normalization is a natural way to visualize loss function geometry. Second, we notice that chaotic landscapes result in worse training and test error, while more convex landscapes have lower error values. In fact, the most convex landscapes generalize the best of all, and show no noticeable chaotic behavior." We can sum this up as: the more convex a loss landscape is, the better the model tends to perform.
Referring to the above pictures, both the IL policy and the discriminator trained on Cartpole show a very smooth surface, each presenting a single, well-defined minimum. In a slightly more complex environment such as Lunarlander, the IL policy presents a large flat region where the model can converge, and the loss becomes more irregular only near the edges of this region. The surface of the discriminator, on the other hand, is more chaotic, but we can still notice some local minima near the center. The loss landscapes of the imitation policy and the discriminator trained on Walker2d are both quite irregular and, coincidentally, the two shapes seem almost complementary. Both landscapes also have a steep, narrow minimum close to the center.
Understanding the discriminator
Comparing the loss landscapes of the Cartpole and Lunarlander discriminators, we would expect the more chaotic landscape of Lunarlander to make the model less able to maintain its accuracy on new observations that were not present in the expert dataset used for training.
Next, we compare the output of the discriminator D(s, a) during training on two expert datasets: the first is the same expert dataset used for training, and the second contains trajectories collected with a different expert whose performance is sub-optimal with respect to the first. In this way, the second dataset contains trajectories that were never encountered during training. This mirrors the common practice of using separate training and validation datasets to train models and select the best-performing version.
Note: recall that, for a well-trained policy, the discriminator output converges to 0, while a random discriminator outputs values close to 0.5. Trajectories coming from a different source should be classified as 1.
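The comparison between the two datasets can be summarized with a few statistics of D(s, a) per training checkpoint. The sketch below is hypothetical: `discriminator` stands for any callable returning outputs in (0, 1), the 0.5 decision threshold is an assumption, and the toy logistic discriminator and Gaussian "datasets" only illustrate usage; they are not the data or models from the experiments.

```python
import numpy as np

def discriminator_report(discriminator, states, actions, threshold=0.5):
    """Summarize D(s, a) on one dataset of (s, a) pairs.

    `discriminator` is any callable mapping batched (states, actions) to outputs in (0, 1).
    Outputs above `threshold` (0.5 is an assumption) count as "not from the training expert".
    """
    d = np.asarray(discriminator(states, actions), dtype=float).ravel()
    return {"mean": float(d.mean()), "frac_above_threshold": float(np.mean(d > threshold))}

def output_histogram(d_values, bins=10):
    """Normalized histogram of discriminator outputs over [0, 1], e.g. for one training checkpoint."""
    counts, edges = np.histogram(np.asarray(d_values, dtype=float), bins=bins, range=(0.0, 1.0))
    return counts / counts.sum(), edges

# Hypothetical stand-in for a trained discriminator: logistic score of a feature sum.
def toy_discriminator(states, actions):
    return 1.0 / (1.0 + np.exp(-(states.sum(axis=1) + actions)))

rng = np.random.default_rng(0)
train_states, train_actions = rng.normal(-2.0, 1.0, (500, 4)), np.zeros(500)
new_states, new_actions = rng.normal(2.0, 1.0, (500, 4)), np.zeros(500)
print(discriminator_report(toy_discriminator, train_states, train_actions))
print(discriminator_report(toy_discriminator, new_states, new_actions))
```

Computing `output_histogram` at each checkpoint, once per dataset, gives the kind of output distribution over training that the two plots compare.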
In both plots, the output distribution is rather random at the beginning of training and gradually concentrates close to 0 as training progresses. However, the trend differs between the left and the right plot.
In fact, in the right plot, which depicts the discriminator output on novel trajectories, a significant portion of the outputs is close to 1, meaning that the discriminator can successfully distinguish the two sources of trajectories. However, some trajectories are still classified as belonging to the expert used for training (output close to 0). There are two possible explanations. First, some trajectories collected by the sub-optimal expert may still be similar to those of the training expert, and are thus classified accordingly. Second, the discriminator may fail to classify some trajectories due to a lack of generalization, as also suggested by its irregular loss landscape. A more effective model architecture or a more sophisticated policy might fix this flaw, improving both the accuracy in the left plot and the generalization in the right plot.
- The GAIL implementation is based on the DI-engine framework.
- Loss landscape visualization is based on the repository: https://github.com/marcellodebernardi/loss-landscapes