Machine Translation Using RNN
One of the cool things we can use RNNs for is translating text from one language to another. In the past this was done with hand-crafted features and lots of complex conditions, which took a very long time to build and were hard to understand. So let's see how RNNs make life easier and do a better job.
The simplest idea is to use a basic RNN, as in the diagram below. In this diagram the RNN is unrolled over time to make it easier to understand what's happening.
This type of RNN is a sequence-to-sequence RNN (seq2seq).
We first compute word embeddings for the input one-hot word vectors. Then we feed in the embedded words of the sentence to translate, one per time step. At each step, the embedded word vector x is multiplied by a weight matrix W(hx). The previously calculated hidden state (the previous output of the RNN node) is multiplied by a different weight matrix W(hh). The results of these two multiplications are added together, and a nonlinearity such as ReLU or tanh is applied. The result is the next hidden state h.
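The sketch below shows why the embedding step is usually implemented as a table lookup. It is written in pure Python, and the vocabulary, the embedding size and the matrix `E` are made-up toy values, not something taken from the text. It shows that multiplying a one-hot vector by an embedding matrix simply selects one row of that matrix.

```python
# Toy vocabulary and embedding size (illustrative assumptions, not from the text).
vocab = ["the", "cat", "sat", "."]
embed_dim = 3

# Hypothetical embedding matrix E: one row of length embed_dim per vocabulary word.
E = [[0.1 * (i + 1) + 0.01 * j for j in range(embed_dim)] for i in range(len(vocab))]


def one_hot(index, size):
    """Return a one-hot vector with a 1.0 at `index` and 0.0 elsewhere."""
    return [1.0 if k == index else 0.0 for k in range(size)]


def embed_one_hot(x, E):
    """Compute the row vector x^T E (one-hot vector times embedding matrix)."""
    return [sum(x[i] * E[i][j] for i in range(len(x))) for j in range(len(E[0]))]


x_cat = one_hot(vocab.index("cat"), len(vocab))
embedded = embed_one_hot(x_cat, E)
print(embedded == E[vocab.index("cat")])  # True: the product is just row 1 of E
```

Because the product only picks out a row, real implementations index into `E` directly instead of doing the full multiplication.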
This process is repeated for every word of the input sentence. For the first input word x0 there is no previous hidden state, so we set the initial hidden state h0 to all zeros.
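Here is a minimal sketch of this encoder loop in pure Python. It assumes tanh as the nonlinearity and uses small made-up weight matrices. The names `W_hh` and `W_hx` mirror W(hh) and W(hx) above, while `rnn_step` and `encode` are hypothetical helper names. The same two matrices are reused at every time step, and the hidden state starts at all zeros.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn_step(h_prev, x, W_hh, W_hx):
    """One step: h_t = tanh(W_hh h_{t-1} + W_hx x_t)."""
    return [math.tanh(a + b) for a, b in zip(matvec(W_hh, h_prev), matvec(W_hx, x))]


def encode(embedded_words, W_hh, W_hx):
    """Run the RNN over a sentence and return the hidden state after each word."""
    h = [0.0] * len(W_hh)  # h0: no previous state before the first word
    states = []
    for x in embedded_words:
        h = rnn_step(h, x, W_hh, W_hx)  # same weights at every time step
        states.append(h)
    return states


# Toy sizes and weights (assumptions): hidden size 2, embedding size 3.
W_hh = [[0.5, -0.1], [0.2, 0.3]]
W_hx = [[0.1, 0.0, -0.2], [0.0, 0.4, 0.1]]
sentence = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # three embedded words
states = encode(sentence, W_hh, W_hx)
print(states[-1])  # final hidden state summarising the sentence
```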
Also note that sentences can have different lengths, so we need a stop token (e.g. a full stop) that indicates the end of the sentence. We hope the model will learn when to predict this stop token in its output. The stop token is basically just an extra 'word' in our training data.
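As a sketch of treating the stop token as an extra word, the preprocessing step can append it to every sentence and give it its own index in the vocabulary. The token string, the toy corpus and the helper names here are assumptions for illustration. The text uses a full stop, while many systems use a dedicated token such as `<eos>` instead.

```python
STOP = "."  # the stop token; the text uses a full stop, many systems use "<eos>"


def add_stop_token(tokens):
    """Append the stop token unless the sentence already ends with it."""
    if tokens and tokens[-1] == STOP:
        return list(tokens)
    return list(tokens) + [STOP]


def build_vocab(sentences):
    """Give every word, including the stop token, its own integer index."""
    vocab = {STOP: 0}
    for sentence in sentences:
        for word in sentence:
            vocab.setdefault(word, len(vocab))
    return vocab


corpus = [add_stop_token(s) for s in [["the", "cat", "sat"], ["hello", "."]]]
vocab = build_vocab(corpus)
print(corpus)  # [['the', 'cat', 'sat', '.'], ['hello', '.']]
print(vocab)   # {'.': 0, 'the': 1, 'cat': 2, 'sat': 3, 'hello': 4}
```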
Once we reach the stop token, we move to the decoder RNN node to start producing output vectors.
To get the output y at each time step from the decoder RNN, we multiply the hidden state h by another weight matrix W(S) to get a vector with one score per vocabulary word. A softmax is then applied to this vector, which gives us our final output: a probability for each word. This final output tells you which word was predicted at that time step.
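A minimal sketch of this output layer, y = softmax(W(S) h), in pure Python follows. The three-word vocabulary, `W_S` and `h` are toy values assumed for illustration, and picking the most probable word (greedy choice) is one simple way to read off the prediction.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def softmax(scores):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def decoder_output(h, W_S, vocab):
    """Return the softmax distribution over words and the most probable word."""
    probs = softmax(matvec(W_S, h))  # one score per vocabulary word
    best = max(range(len(probs)), key=probs.__getitem__)
    return probs, vocab[best]


vocab = [".", "le", "chat"]
W_S = [[0.2, -0.1], [0.5, 0.3], [-0.4, 0.9]]  # one row per vocabulary word (toy)
h = [0.6, 0.8]                                # decoder hidden state (toy)
probs, word = decoder_output(h, W_S, vocab)
print(word)  # le
```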
This model is very simple and in practice might only work for very short sentences (2-3 words). This is because basic RNNs have trouble remembering more than a few steps into the past (the vanishing gradient problem).
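The following small sketch makes the vanishing gradient concrete. It uses a scalar RNN, h_t = tanh(w·h_{t-1} + x_t), with an assumed recurrent weight w = 0.9 and a constant toy input. By the chain rule, the gradient of the last hidden state with respect to h0 is a product of one factor w·(1 − h_t²) per step. Each factor is below 1 here, so the product shrinks geometrically as the sequence gets longer. With large weights the same product can instead blow up, which is the exploding gradient problem.

```python
import math


def gradient_through_time(w, inputs, h0=0.0):
    """Scalar RNN h_t = tanh(w * h_{t-1} + x_t).

    Returns d h_T / d h_0, which by the chain rule is the product over time
    steps of d h_t / d h_{t-1} = w * (1 - h_t ** 2).
    """
    h, grad = h0, 1.0
    for x in inputs:
        h = math.tanh(w * h + x)
        grad *= w * (1.0 - h * h)  # this step's local derivative
    return grad


for steps in (1, 3, 10, 20):
    print(steps, gradient_through_time(0.9, [0.5] * steps))
```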
So what changes can we make to this simple model to start making things better?
Train different RNN weights for encoding and decoding. The model above uses the same RNN node for both encoding and decoding, which is clearly not optimal for learning.
At the decode stage, rather than taking only the previous hidden state as input, as we do above, we now also include the last hidden state of the encoder (we call this C in the diagram below). We also include the last predicted output word vector. This should help the model know that it has just output a certain word, so that it does not output that word again.
So the decoder node now has 3 weight matrices: one for the previous hidden state h, one for the last predicted word vector y and one for the last encoder hidden state C. Each input is multiplied by its matrix, the results are added up, and the nonlinearity is applied to get the next decoder hidden state.
The decode stage uses different weight matrices from the encode stage, but, as in the encoder, the same decoder matrices are reused at every time step of the decode stage.
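Here is a minimal sketch of one step of this improved decoder, assuming a tanh nonlinearity and toy sizes. The names `W_hh`, `W_hy` and `W_hc` are hypothetical labels for the three decoder-only matrices, and they are separate from the encoder's weights. Starting the decoder from C is also an assumption made here for illustration.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def decoder_step(h_prev, y_prev, c, W_hh, W_hy, W_hc):
    """h_t = tanh(W_hh h_{t-1} + W_hy y_{t-1} + W_hc c).

    W_hh, W_hy and W_hc belong to the decoder only and are reused at every
    decoding step; c (the last encoder hidden state) is fed in at every step.
    """
    terms = zip(matvec(W_hh, h_prev), matvec(W_hy, y_prev), matvec(W_hc, c))
    return [math.tanh(a + b + d) for a, b, d in terms]


# Toy sizes (assumptions): hidden size 2, output word vectors of size 3.
W_hh = [[0.5, -0.1], [0.2, 0.3]]
W_hy = [[0.1, 0.0, 0.2], [0.0, -0.3, 0.1]]
W_hc = [[0.4, 0.0], [0.0, 0.4]]
c = [0.3, -0.2]           # last encoder hidden state C
y_prev = [0.0, 1.0, 0.0]  # previously predicted word vector (toy one-hot)
h1 = decoder_step(c, y_prev, c, W_hh, W_hy, W_hc)  # assumption: start from C
print(h1)
```

In a full decoder, `y_prev` would be updated at each step from the word the output layer just predicted.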
Make it deeper! Add more RNN layers to your model.
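One simple way to make the model deeper is to stack RNN layers. The sequence of hidden states from one layer then becomes the input sequence of the next layer. The sketch assumes tanh layers, all-zero initial states and toy weights, and `rnn_layer` and `deep_rnn` are hypothetical names.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn_layer(inputs, W_hh, W_hx):
    """Run one tanh RNN layer over a sequence; return its hidden state at every step."""
    h = [0.0] * len(W_hh)
    outputs = []
    for x in inputs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W_hh, h), matvec(W_hx, x))]
        outputs.append(h)
    return outputs


def deep_rnn(inputs, layers):
    """Stack layers: the hidden states of layer k are the inputs of layer k + 1."""
    seq = inputs
    for W_hh, W_hx in layers:
        seq = rnn_layer(seq, W_hh, W_hx)
    return seq


layers = [
    ([[0.5, -0.1], [0.2, 0.3]], [[0.1, 0.0, -0.2], [0.0, 0.4, 0.1]]),  # layer 1: 3-dim inputs
    ([[0.3, 0.0], [0.1, 0.2]], [[0.6, -0.3], [0.2, 0.5]]),              # layer 2: 2-dim inputs
]
sentence = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
top_states = deep_rnn(sentence, layers)
print(top_states[-1])
```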
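Maybe train a bidirectional encoder. A bidirectional encoder runs over the sentence in both directions, so at each time step it takes the hidden state not only from the previous time step (reading left to right) but also from the next time step (reading right to left).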
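Here is a minimal sketch of a bidirectional encoder. One RNN reads the sentence left to right, and a second RNN with its own weights reads it right to left. The backward states are flipped back into the original order and concatenated with the forward states at each time step. The tanh layers, toy weights and helper names are assumptions for illustration.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn_layer(inputs, W_hh, W_hx):
    """Run one tanh RNN layer over a sequence; return its hidden state at every step."""
    h = [0.0] * len(W_hh)
    outputs = []
    for x in inputs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W_hh, h), matvec(W_hx, x))]
        outputs.append(h)
    return outputs


def bidirectional_encode(inputs, fwd_weights, bwd_weights):
    """Concatenate forward and backward hidden states for every time step."""
    forward = rnn_layer(inputs, *fwd_weights)
    backward = rnn_layer(inputs[::-1], *bwd_weights)[::-1]  # re-align with original order
    return [f + b for f, b in zip(forward, backward)]


fwd = ([[0.5, -0.1], [0.2, 0.3]], [[0.1, 0.0, -0.2], [0.0, 0.4, 0.1]])
bwd = ([[0.4, 0.1], [-0.2, 0.3]], [[0.0, 0.3, 0.1], [0.2, -0.1, 0.0]])
sentence = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
states = bidirectional_encode(sentence, fwd, bwd)
print(states[0])  # 4 numbers: 2 from the forward pass, 2 from the backward pass
```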
An alternative to training a bidirectional encoder is to feed the encoder the input in reverse order. This means that if you would normally train with words A, B, C mapping to X, Y, you instead train with words C, B, A mapping to X, Y. This is a simpler optimisation problem because it brings words that are translated into each other closer together, and hence causes fewer vanishing gradient issues. (This only works if the languages align well, like French and English.)
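The sketch below shows the reversal and measures the effect. It assumes an idealised one-to-one, in-order word alignment and a sequence laid out as source words, then the stop token, then target words. Under that assumption it counts how many time steps separate each source word from the target word it translates into. The helper names are hypothetical.

```python
def make_training_pair(source, target, reverse_source=True):
    """Return (encoder input, decoder target), optionally reversing the source words."""
    encoder_input = list(reversed(source)) if reverse_source else list(source)
    return encoder_input, list(target)


def aligned_distances(n, reverse_source):
    """Time-step distance between source word i and target word i.

    Assumes n words on each side, laid out as source + [stop token] + target,
    with a one-to-one, in-order alignment (an idealisation).
    """
    src_pos = [(n - 1 - i) if reverse_source else i for i in range(n)]
    tgt_pos = [n + 1 + i for i in range(n)]
    return [t - s for s, t in zip(src_pos, tgt_pos)]


print(make_training_pair(["A", "B", "C"], ["X", "Y"]))  # (['C', 'B', 'A'], ['X', 'Y'])
print(aligned_distances(3, reverse_source=False))      # [4, 4, 4]
print(aligned_distances(3, reverse_source=True))       # [2, 4, 6]
```

In this idealised setting the average distance stays the same. However, the first words of the two sentences end up only a couple of steps apart, which gives the optimiser short paths to learn from early in the output.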
That is probably as good as we can get with the simple RNN unit. If we want to do better, we need a different type of RNN unit, called the Gated Recurrent Unit (GRU for short).
The basic idea of the GRU is to keep memories around that capture long-distance dependencies, and the model learns when to do this. It also allows error signals (gradients) to flow back at different strengths depending on the input.
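Here is a minimal sketch of a single GRU step in pure Python, using the standard update-gate/reset-gate formulation without bias terms. It has an update gate z and a reset gate r. Where z is close to 1, the old hidden state is copied through almost unchanged, which is how the unit can keep a memory over many steps. Some texts swap the roles of z and 1 − z. The parameter dictionary `p`, its keys and the toy values are assumptions for illustration.

```python
import math


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def sigmoid(a):
    """Numerically stable logistic function."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def gru_step(h_prev, x, p):
    """One GRU step (no biases):
    z  = sigmoid(W_z x + U_z h)        update gate
    r  = sigmoid(W_r x + U_r h)        reset gate
    h~ = tanh(W x + U (r * h))         candidate state
    h  = z * h + (1 - z) * h~          keep the old memory where z is near 1
    """
    z = [sigmoid(a + b) for a, b in zip(matvec(p["W_z"], x), matvec(p["U_z"], h_prev))]
    r = [sigmoid(a + b) for a, b in zip(matvec(p["W_r"], x), matvec(p["U_r"], h_prev))]
    reset_h = [ri * hi for ri, hi in zip(r, h_prev)]
    h_tilde = [math.tanh(a + b) for a, b in zip(matvec(p["W"], x), matvec(p["U"], reset_h))]
    return [zi * hi + (1.0 - zi) * ht for zi, hi, ht in zip(z, h_prev, h_tilde)]


# Toy parameters (assumptions): 1-dim inputs, hidden size 2.
p = {
    "W_z": [[0.5], [-0.5]], "U_z": [[0.1, 0.0], [0.0, 0.1]],
    "W_r": [[0.3], [0.3]],  "U_r": [[0.0, 0.2], [0.2, 0.0]],
    "W":   [[1.0], [-1.0]], "U":   [[0.5, 0.0], [0.0, 0.5]],
}
h = [0.0, 0.0]
for x in ([1.0], [0.5], [-1.0]):
    h = gru_step(h, x, p)
print(h)
```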
Another way of viewing the whole process in action is the diagram below. Start from the bottom left and work up and across until you reach the red circle. From there, the path goes to the left and up and continues until the stop token (full stop) is predicted. It is important to remember that the weight matrix used to multiply the inputs at each step of the encoder is exactly the same; it does not differ between time steps.