SeqGAN: text generation with generative models

In this post we review the recent history of research on Natural Language Generation (NLG), a branch of Natural Language Processing (NLP). Generating realistic, human-like language has long been a challenge for researchers, and it has recently come into greater focus with the release of large neural NLP models such as GPT and BERT. We focus on GAN-based models, giving an overview of the strengths of each model and the problems it faces. We also look at a few specific problems that matter to NLG researchers and show how they were overcome. NLG has only recently become tractable, and we believe a targeted review of the latest trends will help build a big-picture understanding and further active research.

Introduction

The ability to generate coherent and semantically meaningful text plays a key role in many natural language processing applications, such as machine translation, dialogue generation, image captioning, text summarization, story telling, and poetry generation. Most early work focused on task-specific applications in supervised settings, but generic unsupervised text generation, which aims to mimic the distribution of real text in a corpus, has recently drawn much attention. We therefore consider deep-learning text generation to be still in its infancy compared with more mature NLP tasks, such as text classification and knowledge graphs, which have been studied in greater detail. A typical approach is to train a recurrent neural network (RNN) to maximize the log-likelihood of each ground-truth word given the previously observed words. This approach, however, suffers from so-called exposure bias, caused by the discrepancy between training and inference: during inference the model generates each word based on the words it has previously generated, whereas during training it is conditioned on ground-truth words. This problem has been addressed with a curriculum learning strategy (scheduled sampling) that gradually moves training from a fully guided scheme, which feeds the true previous token (teacher forcing), towards a less guided scheme that mostly feeds the generated token instead; however, this strategy has been shown to be fundamentally inconsistent. Moreover, because of their sequential nature, RNN language models do not learn a vector representation of the full sentence, so they do not expose an interpretable representation of global features such as style or topic.
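To make exposure bias concrete, here is a minimal sketch of which prefix an autoregressive model conditions on under teacher forcing, free running (as at inference), and scheduled sampling. The bigram table `TOY_BIGRAM`, `toy_predict_next` and `conditioning_prefix` are hypothetical illustrations, not code from any of the works discussed; a real model would replace `toy_predict_next` with an RNN step.

```python
import random
from typing import Callable, List

# Hypothetical toy "model": predicts the next token from the last token only,
# and gets the second word of the example sentence wrong ("dog" instead of "cat").
TOY_BIGRAM = {"<s>": "a", "a": "dog", "dog": "runs", "cat": "sleeps"}


def toy_predict_next(prefix: List[str]) -> str:
    return TOY_BIGRAM.get(prefix[-1], "</s>")


def conditioning_prefix(
    ground_truth: List[str],
    predict_next: Callable[[List[str]], str],
    teacher_prob: float,
    rng: random.Random,
) -> List[str]:
    """Build the prefix the model conditions on while producing a sequence.

    teacher_prob = 1.0 -> teacher forcing: always append the ground-truth token.
    teacher_prob = 0.0 -> free running: always append the model's own prediction.
    0 < teacher_prob < 1 -> scheduled sampling: choose between the two at every step.
    """
    prefix = ["<s>"]
    for target in ground_truth:
        if rng.random() < teacher_prob:
            prefix.append(target)
        else:
            prefix.append(predict_next(prefix))
    return prefix


sentence = ["a", "cat", "sleeps"]
rng = random.Random(0)
print(conditioning_prefix(sentence, toy_predict_next, 1.0, rng))  # ['<s>', 'a', 'cat', 'sleeps']
print(conditioning_prefix(sentence, toy_predict_next, 0.0, rng))  # ['<s>', 'a', 'dog', 'runs']
```

Under teacher forcing the model never sees its own mistakes; when free running, the early error "dog" changes every later input, which is exactly the train/inference mismatch described above. In scheduled sampling, `teacher_prob` is decayed from 1 towards 0 over the course of training.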

Reinforcement learning for text generation

Another possible approach is to use RL policy-gradient algorithms such as REINFORCE to train a model that optimizes a non-differentiable metric such as BLEU. However, this technique is of limited use in practice: BLEU is computationally expensive and is not a particularly strong metric, since it only measures the similarity of n-gram statistics between the generated text and the reference corpus. Generative Adversarial Nets (GANs), first proposed for continuous data (image generation, etc.), were later extended to discrete, sequential data to perform text generation and have shown promising results. Because text samples are discrete, text generation is modeled as a sequential decision-making process in which the state is the previously generated words, the action is the next word to be generated, and the generative net G is a stochastic policy that maps the current state to a distribution over the action space. Once a whole text has been generated, it is fed to the discriminative net D, a classifier trained to distinguish real from generated text samples, to obtain a reward signal for updating G. However, text sequences are made of discrete tokens, and sampling a discrete token is not differentiable, so the gradient of D cannot be propagated back to G and GAN optimization becomes challenging. Hence, rather than backpropagating the discriminator's gradient directly into the generator, NLP-oriented GAN models are usually trained with RL techniques: policy gradient in the case of SeqGAN and RankGAN, or more advanced hierarchical RL algorithms in the case of LeakGAN (adversarial reinforcement learning).
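The core REINFORCE idea, optimizing a reward that can only be evaluated and not differentiated, can be sketched on a deliberately tiny policy: one independent softmax over the vocabulary per position. The reward `overlap_reward` is a hypothetical, crude stand-in for an n-gram metric such as BLEU, and the moving-average baseline is a standard variance-reduction addition, not something stated in the text.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def overlap_reward(sample: List[int], reference: List[int]) -> float:
    """Black-box, non-differentiable reward: fraction of positions that match
    the reference (a crude stand-in for an n-gram metric such as BLEU)."""
    return sum(s == r for s, r in zip(sample, reference)) / len(reference)


def train_reinforce(reference: List[int], vocab_size: int, steps: int = 3000,
                    lr: float = 0.5, seed: int = 0) -> List[List[float]]:
    rng = random.Random(seed)
    # Hypothetical policy: one independent softmax over the vocabulary per position.
    logits = [[0.0] * vocab_size for _ in reference]
    baseline = 0.0
    for _ in range(steps):
        probs = [softmax(z) for z in logits]
        sample = [rng.choices(range(vocab_size), weights=p)[0] for p in probs]
        reward = overlap_reward(sample, reference)  # only a scalar, no gradient
        advantage = reward - baseline
        baseline = 0.9 * baseline + 0.1 * reward  # moving-average baseline (variance reduction)
        for z, p, a in zip(logits, probs, sample):
            for k in range(vocab_size):
                # d log softmax(z)[a] / d z[k] = 1[k == a] - p[k]
                z[k] += lr * advantage * ((1.0 if k == a else 0.0) - p[k])
    return logits


reference = [1, 3, 0, 4]
learned = train_reinforce(reference, vocab_size=5)
greedy = [max(range(5), key=lambda k: z[k]) for z in learned]
print(greedy)
```

The reward is only ever evaluated, never differentiated: the gradient flows through log-probabilities of the sampled tokens, which is what allows REINFORCE to optimize BLEU-like metrics, and what GAN-based text models exploit with the discriminator's output as the reward.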

Generative Adversarial Networks

Generative Adversarial Networks (GANs) are models whose goal is to generate new samples from a distribution that is as similar as possible to the distribution of the original data. They are trained by solving a minimax problem over the adversarial value function V(D,G), defined as

\min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1-D(G(z)))].

GANs are usually composed of two models: a generator G and a discriminator D. In this post, we refer to the parameters of the generator as θ and to the parameters of the discriminator as φ. The objective of the generator is to find parameters θ that minimize the distance between the estimated distribution and the distribution of the actual data, under the supervision of the discriminator. During training, the generator learns to generate new samples from a distribution resembling the original one, while the discriminator learns to distinguish generated samples from real ones. Ideally, both networks improve at their tasks until G generates new samples (text, in our case) that are as similar as possible to human-written text. Unfortunately, applying GANs to sequence generation raises two problems. First, GANs were designed to generate continuous data, whereas text is a sequence of discrete tokens: sampling a token is not differentiable, so the discriminator's gradient cannot flow back into the generator. Second, the discriminator gives a score (or loss) for an entire sequence only once it has been completely generated. For a partially generated sequence, it is therefore non-trivial to judge both how good it is at the present time and what score it will receive once it is completed.
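The minimax objective can be checked numerically on toy distributions over a finite set of outcomes. This sketch uses the standard result that, for a fixed G, the best discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), and that the value at D* reaches its minimum, -log 4, exactly when the generator's distribution p_g matches p_data. The dictionaries below are hypothetical examples, not real data.

```python
import math
from typing import Dict

Dist = Dict[str, float]


def value_fn(p_data: Dist, p_g: Dist, d: Dist) -> float:
    """V(D, G) for distributions over a finite set of outcomes.

    The second term is written over x = G(z) ~ p_g, which equals E_z[log(1 - D(G(z)))].
    """
    v = 0.0
    for x in set(p_data) | set(p_g):
        if p_data.get(x, 0.0) > 0.0:
            v += p_data[x] * math.log(d[x])
        if p_g.get(x, 0.0) > 0.0:
            v += p_g[x] * math.log(1.0 - d[x])
    return v


def optimal_discriminator(p_data: Dist, p_g: Dist) -> Dist:
    """For a fixed G, V is maximised pointwise by D*(x) = p_data(x) / (p_data(x) + p_g(x))."""
    support = {x for x in set(p_data) | set(p_g) if p_data.get(x, 0.0) + p_g.get(x, 0.0) > 0.0}
    return {x: p_data.get(x, 0.0) / (p_data.get(x, 0.0) + p_g.get(x, 0.0)) for x in support}


p_data = {"a": 0.5, "b": 0.5}
for p_g in ({"a": 0.5, "b": 0.5}, {"a": 0.9, "b": 0.1}):
    d_star = optimal_discriminator(p_data, p_g)
    print(p_g, value_fn(p_data, p_g, d_star), -math.log(4))
```

With a matching generator the optimal discriminator outputs 0.5 everywhere and V equals -log 4; with a mismatched generator the discriminator can exploit the difference and V is larger, which is the signal the generator is trained to reduce.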

RL-based GAN for text generation

To address the two problems described in the previous section, we can treat sequence generation as a sequential decision-making process. The generator G is a reinforcement learning agent, the state of the environment is the sequence of tokens generated so far, and the action is the next token to be generated. The reward is provided by a discriminator D, which evaluates the generated sequence at the end of each episode; a Monte Carlo approach propagates this evaluation to intermediate steps to guide the learning process of the generator. The reward is estimated as the likelihood that the generated sentence would fool the discriminator. Because the gradient cannot flow back to the generative model when its output is discrete, the generator's policy is instead trained directly with RL techniques such as policy gradient, which naturally sidesteps the difficulty of differentiating discrete data in a conventional GAN. In this way, the generator is trained to pick the optimal action (word) and learns its policy from the estimated overall rewards. As in the original GAN, the discriminative model D is trained on positive examples from the real sequence data and negative examples from the synthetic sequences produced by the generative model G.

As for the implementation, the generator can be a recurrent neural network (RNN) with LSTM or GRU cells, which mitigate the vanishing-gradient problem of backpropagation through time. An RNN maps the input embeddings of a sequence x1,…,xt to a sequence of hidden states h1,…,ht using a recurrent function, and then produces the distribution of the next token by applying a softmax layer to the last hidden state. The discriminator can consist of a CNN that convolves over the k-dimensional token embeddings with multiple kernels and applies max-pooling over time to extract multiple features.
Finally, a fully connected layer with sigmoid activation is used to output the probability that the input sequence is real.
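The discriminator's forward pass (embeddings, convolutions of several widths, max-pooling over time, a fully connected layer and a sigmoid) can be sketched in plain Python. Everything here is a hypothetical, untrained illustration with random weights: the class name, embedding size, filter widths, number of filters and the `<pad>` token are assumptions, and a real implementation would use a deep-learning framework and learn the weights with the loss given later in this post.

```python
import math
import random
from typing import List, Sequence


def relu(x: float) -> float:
    return x if x > 0.0 else 0.0


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class TinyCNNDiscriminator:
    """Hypothetical, untrained CNN text classifier with random weights:
    embeddings -> 1-D convolutions of several widths -> max-pool over time
    -> fully connected layer -> sigmoid."""

    def __init__(self, vocab: Sequence[str], emb_dim: int = 4,
                 widths: Sequence[int] = (2, 3), filters_per_width: int = 2, seed: int = 0):
        rng = random.Random(seed)
        self.max_width = max(widths)
        self.emb = {tok: [rng.gauss(0.0, 0.5) for _ in range(emb_dim)]
                    for tok in list(vocab) + ["<pad>"]}
        # Each filter is (width, kernel of shape width x emb_dim, bias).
        self.filters = [(w, [[rng.gauss(0.0, 0.5) for _ in range(emb_dim)] for _ in range(w)], 0.0)
                        for w in widths for _ in range(filters_per_width)]
        self.out_w = [rng.gauss(0.0, 0.5) for _ in self.filters]
        self.out_b = 0.0

    def features(self, tokens: Sequence[str]) -> List[float]:
        # Pad short inputs so that even the widest filter fits at least once.
        tokens = list(tokens) + ["<pad>"] * max(0, self.max_width - len(tokens))
        x = [self.emb[t] for t in tokens]
        feats = []
        for width, kernel, bias in self.filters:
            # Slide the filter over every window of `width` consecutive embeddings.
            acts = [relu(bias + sum(kernel[i][j] * x[s + i][j]
                                    for i in range(width) for j in range(len(kernel[i]))))
                    for s in range(len(x) - width + 1)]
            feats.append(max(acts))  # max-pooling over time: one feature per filter
        return feats

    def prob_real(self, tokens: Sequence[str]) -> float:
        # Fully connected layer over the pooled features, then sigmoid.
        logit = self.out_b + sum(w * f for w, f in zip(self.out_w, self.features(tokens)))
        return sigmoid(logit)


disc = TinyCNNDiscriminator(["a", "cat", "dog", "sleeps", "runs"])
print(disc.prob_real(["a", "cat", "sleeps"]))
```

Max-pooling over time turns a sequence of any length into a fixed-size feature vector (one value per filter), which is why this architecture suits variable-length generated sentences.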

In the next section we look in detail at the architecture of one of these models, SeqGAN, which is commonly used for text generation.

SeqGAN

The text generation problem can be formulated as the following sequence generation task. Given a corpus of real-world structured sequences, train a θ-parameterized generative model Gθ to produce a sequence Y1:T = (y1,…,yt,…,yT), yt ∈ Y, where Y is the vocabulary of candidate tokens. Interpreting the problem as a reinforcement learning scenario, at timestep t the state st is the sequence of tokens produced so far, (y1,…,yt-1), and the action a is the next token yt, which the generator selects from Y with probability Gθ(yt|Y1:t-1). The objective of the generator is then to generate a sequence from the start state s0 that maximizes its expected end reward:

L(\theta)=E[R_T|s_0,\theta]=\sum_{y_1\in Y}G_{\theta}(y_1|s_0)\cdot Q_{D_{\phi}}^{G_{\theta}}(s_0,y_1),

where RT is the reward for a complete sequence and QGθ(s, a) is the action-value function of a sequence, that is, the expected cumulative reward obtained by starting from state s, taking action a, and then following policy Gθ. To estimate the action-value function, SeqGAN follows the REINFORCE algorithm and uses as reward the discriminator's estimated probability that the generated sequence is real:

Q_{D_{\phi}}^{G_{\theta}}(a=y_{T},s=Y_{1:T-1})=D_{\phi}(Y_{1:T}).
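The decomposition of the objective can be verified by brute force on a toy problem with a two-token vocabulary and sequences of length 2. The generator `gen_prob` and discriminator `disc` below are hypothetical lookup tables chosen for illustration; the sketch computes the expected end reward directly over all sequences and again as the sum over first tokens of Gθ(y1|s0)·Q(s0, y1).

```python
import itertools
from typing import Sequence

VOCAB = ["a", "b"]
T = 2  # sequence length


def gen_prob(token: str, prefix: Sequence[str]) -> float:
    """Hypothetical generator G_theta(y_t | Y_{1:t-1}): likes repeating the previous token."""
    if not prefix:
        return {"a": 0.7, "b": 0.3}[token]
    return 0.8 if token == prefix[-1] else 0.2


def disc(seq: Sequence[str]) -> float:
    """Hypothetical discriminator D_phi(Y_{1:T}): probability that a full sequence is real."""
    return {("a", "a"): 0.9, ("a", "b"): 0.4, ("b", "a"): 0.3, ("b", "b"): 0.6}[tuple(seq)]


def seq_prob(tokens: Sequence[str], prefix: Sequence[str] = ()) -> float:
    """Probability that G continues `prefix` with exactly `tokens`."""
    p, cur = 1.0, list(prefix)
    for tok in tokens:
        p *= gen_prob(tok, cur)
        cur.append(tok)
    return p


def q_value(prefix: Sequence[str], token: str) -> float:
    """Exact Q(s = prefix, a = token): expected D(Y_{1:T}) after appending `token`
    and letting G complete the sequence. When the sequence is already full
    (t = T), the only completion is empty and Q reduces to D(Y_{1:T})."""
    head = list(prefix) + [token]
    return sum(seq_prob(tail, head) * disc(head + list(tail))
               for tail in itertools.product(VOCAB, repeat=T - len(head)))


expected_reward = sum(seq_prob(s) * disc(s) for s in itertools.product(VOCAB, repeat=T))
objective = sum(gen_prob(y, []) * q_value([], y) for y in VOCAB)
print(expected_reward, objective)
```

Both quantities agree (0.722 up to floating-point rounding in this toy setting). Exhaustive enumeration is only possible for toy sizes; with a real vocabulary and sequence length, Q has to be estimated by sampling, which motivates the Monte Carlo search described next.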

Since we actually need a reward at every timestep, and not only once the sequence is finished, we apply Monte Carlo search with a roll-out policy G (the same policy as the generator) to sample the unknown last T−t tokens, which yields a long-term reward estimate for the partial sequence Y1:t with t < T. We write an N-time Monte Carlo search starting from Y1:t as

\{Y_{1:T}^1,\ldots,Y_{1:T}^N\}=MC^{G}(Y_{1:t};N),

where each sample Y^n_{1:T} keeps the given prefix Y1:t and its remaining tokens Y^n_{t+1:T} are sampled with the roll-out policy G starting from the current state. To reduce variance and get a more accurate assessment of the action value, we run the roll-out policy from the current state to the end of the sequence N times, obtaining a batch of N output samples, and average the discriminator's scores:

Q_{D_{\phi}}^{G_{\theta}}(s=Y_{1:t-1},a=y_t)=\begin{cases}\frac{1}{N}\sum_{n=1}^N D_{\phi}(Y_{1:T}^n),\; Y_{1:T}^n \in MC^{G}(Y_{1:t};N) & \text{if } t<T,\\ D_{\phi}(Y_{1:t}) & \text{if } t=T.\end{cases}
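A minimal sketch of this Monte Carlo action-value estimate, under toy assumptions: `gen_dist` is a hypothetical roll-out policy over a two-token vocabulary, `disc` is a hypothetical discriminator that simply prefers sequences with more "a" tokens, and `exact_q_value` enumerates all completions only so that the sampled estimate can be compared against the true value.

```python
import itertools
import random
from typing import Dict, List, Sequence

VOCAB = ["a", "b"]
T = 3  # full sequence length


def gen_dist(prefix: Sequence[str]) -> Dict[str, float]:
    """Hypothetical roll-out policy G(y_t | Y_{1:t-1}); in SeqGAN it is the generator itself."""
    if not prefix:
        return {"a": 0.6, "b": 0.4}
    return {tok: (0.7 if tok == prefix[-1] else 0.3) for tok in VOCAB}


def disc(seq: Sequence[str]) -> float:
    """Hypothetical discriminator: sequences with more 'a' tokens look more 'real'."""
    return (1 + list(seq).count("a")) / (T + 2)


def rollout(prefix: Sequence[str], rng: random.Random) -> List[str]:
    """One Monte Carlo sample: complete Y_{1:t} to length T with the roll-out policy."""
    seq = list(prefix)
    while len(seq) < T:
        dist = gen_dist(seq)
        seq.append(rng.choices(list(dist), weights=list(dist.values()))[0])
    return seq


def mc_q_value(prefix: Sequence[str], token: str, n: int, rng: random.Random) -> float:
    """Q(s = Y_{1:t-1}, a = y_t): mean of D over n roll-outs if t < T, else D(Y_{1:T})."""
    head = list(prefix) + [token]
    if len(head) == T:
        return disc(head)
    return sum(disc(rollout(head, rng)) for _ in range(n)) / n


def exact_q_value(prefix: Sequence[str], token: str) -> float:
    """Same quantity computed by enumerating every completion (feasible only for toy sizes)."""
    head = list(prefix) + [token]
    total = 0.0
    for tail in itertools.product(VOCAB, repeat=T - len(head)):
        p, seq = 1.0, list(head)
        for tok in tail:
            p *= gen_dist(seq)[tok]
            seq.append(tok)
        total += p * disc(seq)
    return total


rng = random.Random(0)
for prefix, token in [([], "a"), (["a"], "b"), (["a", "b"], "a")]:
    print(prefix, token, mc_q_value(prefix, token, 2000, rng), exact_q_value(prefix, token))
```

Larger N lowers the variance of the estimate at the cost of N extra generator roll-outs per timestep, which is the main computational overhead of this reward scheme.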

Once the generator has produced a new set of sequences, we re-train the discriminator model by minimizing

\min_\phi -E_{Y\sim p_{data}}[\log D_{\phi}(Y)]-E_{Y\sim G_{\theta}}[\log(1-D_{\phi}(Y))].
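This loss is the binary cross-entropy of a real-vs-generated classifier, and it is the negation of the GAN value function V(D,G), so minimizing it over φ is the max_D step of the minimax game. A minimal sample-based sketch, assuming the discriminator's output probabilities have already been computed for a batch of real and a batch of generated sequences (the function name and the `eps` guard are illustrative choices):

```python
import math
from typing import Sequence


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float], eps: float = 1e-12) -> float:
    """Sample estimate of -E_{Y~p_data}[log D(Y)] - E_{Y~G}[log(1 - D(Y))].

    d_real: D's outputs on real sequences; d_fake: D's outputs on generated ones.
    eps keeps log() finite when D outputs exactly 0 or 1.
    """
    real_term = -sum(math.log(max(p, eps)) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(max(1.0 - p, eps)) for p in d_fake) / len(d_fake)
    return real_term + fake_term


print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))     # undecided D: log 4
print(discriminator_loss([0.99, 0.95], [0.02, 0.1]))  # confident, correct D: small loss
print(discriminator_loss([0.1, 0.2], [0.9, 0.8]))     # fooled D: large loss
```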

Accordingly, the gradient of the objective function L(θ) of the generator can be derived as

\nabla_{\theta} L(\theta)=\sum_{t=1}^T E_{Y_{1:t-1}\sim G_{\theta}}\Big[\sum_{y_t \in Y}\nabla_{\theta}G_{\theta}(y_t|Y_{1:t-1})\cdot Q_{D_{\phi}}^{G_{\theta}}(Y_{1:t-1},y_t)\Big].
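The inner sum over all tokens is usually replaced by sampling: using the likelihood-ratio identity ∇G = G·∇log G, the term becomes E_{y_t~G}[∇log G(y_t|Y1:t-1)·Q(Y1:t-1, y_t)], which the SeqGAN paper approximates with sampled tokens. This sketch checks the two forms against each other for a single state with a softmax policy over three tokens; the logits `z` and action values `q` are made-up numbers, and θ is taken to be the logits themselves for simplicity.

```python
import math
import random
from typing import List


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def exact_grad(z: List[float], q: List[float]) -> List[float]:
    """sum_y dG(y|s)/dz * Q(s, y) for a softmax policy G = softmax(z) at one state s,
    using dG(y)/dz_k = G(y) * (1[k == y] - G(k))."""
    g = softmax(z)
    n = len(z)
    return [sum(g[y] * ((1.0 if k == y else 0.0) - g[k]) * q[y] for y in range(n))
            for k in range(n)]


def sampled_grad(z: List[float], q: List[float], n_samples: int, rng: random.Random) -> List[float]:
    """Likelihood-ratio estimate: average of d log G(y|s)/dz * Q(s, y) with y ~ G."""
    g = softmax(z)
    n = len(z)
    acc = [0.0] * n
    for _ in range(n_samples):
        y = rng.choices(range(n), weights=g)[0]
        for k in range(n):
            acc[k] += ((1.0 if k == y else 0.0) - g[k]) * q[y]
    return [a / n_samples for a in acc]


z = [0.2, -0.1, 0.5]  # hypothetical logits of G at state Y_{1:t-1}
q = [0.9, 0.2, 0.4]   # hypothetical Q(Y_{1:t-1}, y) for each candidate token y
print(exact_grad(z, q))
print(sampled_grad(z, q, 20000, random.Random(0)))
```

For a full sequence, one such term is computed at every timestep t (with Q supplied by the Monte Carlo search above), the terms are summed, and θ is moved along the resulting gradient, since L(θ) is maximized.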

The generator and discriminator are trained alternately. As the generator improves through its g-step updates, the discriminator needs to be retrained periodically to keep pace with it. When training the discriminator, positive examples come from the given dataset S, whereas negative examples are generated by our generator. To keep the classes balanced, the number of negative examples generated for each d-step equals the number of positive examples.
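The alternating schedule and the balanced d-step batches can be sketched as a small driver loop. The callbacks `generate`, `g_update` and `d_update` are hypothetical placeholders for sampling from G, a policy-gradient update of G, and one discriminator update; the loop only shows how g-steps and d-steps interleave and how each discriminator batch gets as many generated negatives as real positives.

```python
import random
from typing import Callable, List, Sequence, Tuple

Batch = List[Tuple[list, int]]


def adversarial_training(
    real_data: Sequence[list],
    generate: Callable[[], list],
    g_update: Callable[[], None],
    d_update: Callable[[Batch], None],
    epochs: int,
    g_steps: int,
    d_steps: int,
    rng: random.Random,
) -> None:
    """Alternate g-steps and d-steps. All callbacks are hypothetical placeholders."""
    for _ in range(epochs):
        for _ in range(g_steps):
            g_update()  # e.g. one policy-gradient update of G using rewards from D
        for _ in range(d_steps):
            # Balanced classes: as many generated negatives (label 0) as real positives (label 1).
            negatives = [generate() for _ in range(len(real_data))]
            batch = [(list(y), 1) for y in real_data] + [(y, 0) for y in negatives]
            rng.shuffle(batch)
            d_update(batch)


real = [["a", "cat", "sleeps"], ["a", "dog", "runs"], ["a", "cat", "runs"]]
batches: List[Batch] = []
g_calls: List[int] = []
adversarial_training(real, generate=lambda: ["a", "a", "a"], g_update=lambda: g_calls.append(1),
                     d_update=batches.append, epochs=2, g_steps=1, d_steps=3, rng=random.Random(0))
print(len(g_calls), len(batches), [sum(label for _, label in b) for b in batches])
# 2 6 [3, 3, 3, 3, 3, 3]
```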

Representation of SeqGAN. Left: D is trained over the real data and the generated data by G. Right: G is trained by policy gradient where the final reward signal is provided by D and is passed back to the intermediate action value via Monte Carlo search.

Metrics: Self-BLEU and NLL-test loss

BLEU (Bilingual Evaluation Understudy) is the metric most commonly used to compare the performance of text generation models. Originally developed to evaluate the output of automatic machine translation systems, it is commonly used to evaluate a generated sentence with respect to one or more reference sentences. It is computed by comparing the n-grams of the generated text with the n-grams of the reference text and counting the number of position-independent matches: the more matches, the higher the score. The highest achievable score is 1, which is reached when the candidate text is identical to one of the reference texts; the lowest is 0, obtained when there are no matches. More specifically, BLEU-N combines the (clipped) precisions of n-grams up to length N (up to 2-grams for BLEU-2, up to 3-grams for BLEU-3, and so on). The final BLEU score is the average of the BLEU scores of the individual generated sentences. Self-BLEU applies the same computation to each generated sentence using the other generated sentences as references, so it measures diversity: a higher Self-BLEU means less diverse output.
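A minimal sketch of sentence-level BLEU-N (geometric mean of clipped n-gram precisions for n = 1..N, times a brevity penalty) and of Self-BLEU built on top of it. This is a simplified, unsmoothed version written for illustration; scores can differ from toolkits such as NLTK, which offer smoothing options, and the function names are our own.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(candidate: Sequence[str], references: List[Sequence[str]], max_n: int = 2) -> float:
    """Unsmoothed BLEU-max_n of one candidate against one or more references."""
    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        cand = ngrams(candidate, n)
        if not cand:
            return 0.0
        # Clip each n-gram count by its maximum count in any single reference.
        max_ref: Counter = Counter()
        for ref in references:
            for gram, c in ngrams(ref, n).items():
                max_ref[gram] = max(max_ref[gram], c)
        matches = sum(min(c, max_ref[gram]) for gram, c in cand.items())
        if matches == 0:
            return 0.0
        log_precision_sum += math.log(matches / sum(cand.values()))
    # Brevity penalty, using the reference length closest to the candidate length.
    c_len = len(candidate)
    r_len = min((abs(len(r) - c_len), len(r)) for r in references)[1]
    bp = 1.0 if c_len > r_len else math.exp(1 - r_len / c_len)
    return bp * math.exp(log_precision_sum / max_n)


def self_bleu(generated: List[Sequence[str]], max_n: int = 2) -> float:
    """Average BLEU of each generated sentence against all the *other* generated ones
    (needs at least two sentences). Higher means less diverse output."""
    scores = [sentence_bleu(g, generated[:i] + generated[i + 1:], max_n)
              for i, g in enumerate(generated)]
    return sum(scores) / len(scores)


ref = "a cat sits on the mat".split()
print(sentence_bleu("the cat sits on a mat".split(), [ref], max_n=2))
print(self_bleu([s.split() for s in ["a cat sits", "a cat sits", "a dog runs"]]))
```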

NLL-test loss is dual to NLL-oracle loss and is used to evaluate the model's capacity to fit real test data. It is computed as the average negative log-likelihood of real test data under the generator:

NLL_{test}=-E_{Y_{1:T}\sim G_{real}}\Big[\sum_{t=1}^{T}\log G_{\theta}(y_t|Y_{1:t-1})\Big],

where Greal denotes the distribution of real data. NLLtest can only be applied to autoregressive generators such as RNNs, because computing it requires Gθ(yt|Y1:t-1), the probability that the generator assigns to each word given the previous ones.
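To show why NLLtest needs an autoregressive generator, here is a minimal sketch in which a hypothetical add-alpha smoothed bigram model stands in for Gθ: any model exposing a `cond_prob(token, prefix)` function (an assumed interface) can be scored on real test sentences, here compared with a uniform baseline. The tiny corpus is made up for illustration.

```python
import math
from collections import Counter, defaultdict
from typing import Callable, List, Sequence

CondProb = Callable[[str, Sequence[str]], float]


def nll_test(test_sequences: List[List[str]], cond_prob: CondProb) -> float:
    """Average over real test sequences of -sum_t log G(y_t | Y_{1:t-1})."""
    total = 0.0
    for seq in test_sequences:
        for t in range(len(seq)):
            total -= math.log(cond_prob(seq[t], seq[:t]))
    return total / len(test_sequences)


def fit_bigram(train: List[List[str]], vocab: List[str], alpha: float = 1.0) -> CondProb:
    """Hypothetical add-alpha smoothed bigram model standing in for an autoregressive G."""
    counts = defaultdict(Counter)
    for seq in train:
        prev = "<s>"
        for tok in seq:
            counts[prev][tok] += 1
            prev = tok

    def cond_prob(token: str, prefix: Sequence[str]) -> float:
        prev = prefix[-1] if prefix else "<s>"
        c = counts[prev]
        return (c[token] + alpha) / (sum(c.values()) + alpha * len(vocab))

    return cond_prob


vocab = ["a", "cat", "dog", "sleeps", "runs"]
train = [["a", "cat", "sleeps"], ["a", "dog", "runs"]]
test = [["a", "cat", "runs"]]
print(nll_test(test, fit_bigram(train, vocab)))        # lower: the model fits the test data better
print(nll_test(test, lambda tok, prefix: 1 / len(vocab)))  # uniform baseline
```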

Datasets

In this section we briefly introduce some datasets commonly used to test text generation models:

  • Penn Treebank: A common evaluation dataset for language modeling (the task of predicting the next word or character in a document). The dataset consists of 929k training words, 73k validation words, and 82k test words. Models are evaluated by perplexity, the exponential of the average per-word negative log-probability (lower is better);
  • Image COCO: The Image COCO dataset was proposed for object detection, segmentation and image captioning tasks. It contains images, bounding boxes, annotations and labels, but to compare text generation models only its image caption annotations are used. The captions contain 4,682 distinct words, and the maximum sentence length is 37. Sentences in this dataset have relatively short and simple patterns, so it can be considered a short-text generation dataset;
  • EMNLP2017 WMT News: It contains news article sentences taken from different corpora. Since most sentences contain rarely used words, only the sentences made up entirely of the 5,700 most commonly used words are kept. The maximum sentence length is 51, so it can be considered a long-text generation dataset.
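Perplexity is the exponential of the average per-word negative log-probability, so it is closely tied to the NLL metric discussed earlier. A minimal sketch, assuming the model's natural-log probability for each test token is already available (the function name is our own):

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the average per-token negative log-probability (natural logs)."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


# A model that is uniform over a 10,000-word vocabulary has perplexity 10,000:
# it is as "confused" as if it were choosing uniformly among 10,000 words.
print(perplexity([math.log(1 / 10_000)] * 20))
# Probabilities 1/2 and 1/8 give a geometric-mean probability of 1/4, i.e. perplexity 4.
print(perplexity([math.log(0.5), math.log(0.125)]))
```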

Generated sentences

In this section we show some sentences generated by SeqGAN on the COCO dataset and compare them with sentences generated by a simple MLE model.

SeqGAN generated sentences:

  • a very tall pointy across the street
  • a bowl full of various cooking in to black kitchen
  • a parked car with a woman hanging over a motorcycle .
  • a bowl full monitor with a monitor next to a couple is painted painted in it .
  • his with an image of a white toilet
  • an image of a motorcycle decorated with tall trees

MLE generated sentences:

  • there are tiled hanging across a large the toilet .
  • a large airplane is on the runway covered above of trees .
  • a sink filled with clutter and wooden cabinets .
  • a white towel on a stool above a blue counter top oven .
  • young girl laying on a table in front of a large hotel
  • a man bending while on a motorcycle with pots on a track travelling down on it

The sentences generated by the MLE model make little or no sense but are fairly grammatical, while SeqGAN generally produces correct sentences that sometimes do not reflect the topic of the given images.

Conclusion

Text generation is a branch of NLP that, although still at a nascent stage, has seen many breakthroughs in recent years. Given how useful and important human-like text generation is, many new challenges and opportunities are expected to emerge in future research. Deep RL has already been shown to outperform humans in other tasks, such as playing the game of Go and mastering multiple Atari games at a superhuman level, and we expect similar breakthroughs to follow in the field of text generation. We therefore think that reviewing recent highlights in this field will further promote the exploration of combinations of reinforcement learning and generative models, leading to better ways of integrating them and of exploiting the state-of-the-art techniques of both fields.
