---
layout: page
permalink: /rnn/
---
Table of Contents:
- Introduction to RNN
- RNN example as Character-level language model
- Multilayer RNNs
- Long Short-Term Memory (LSTM)
In this lecture note, we're going to be talking about Recurrent Neural Networks (RNNs). One great thing about RNNs is that they offer a lot of flexibility in how we wire up the neural network architecture. Normally when we work with neural networks (Figure 1), we are given a fixed-size input vector (red), we process it with some hidden layers (green), and we produce a fixed-size output vector (blue), as depicted in the leftmost model ("Vanilla" Neural Networks) in Figure 1. While "Vanilla" Neural Networks receive a single input and produce a single label for that input, there are tasks where the model produces a sequence of outputs, as shown in the one-to-many model in Figure 1. Recurrent Neural Networks allow us to operate over sequences of inputs, outputs, or both at the same time.
- An example of a one-to-many model is image captioning, where we are given a fixed-size image and produce, through an RNN, a sequence of words that describe the content of that image (second model in Figure 1).
- An example of a many-to-one task is action prediction, where we look at a sequence of video frames instead of a single image and produce a label for the action happening in the video, as shown in the third model in Figure 1. Another example of a many-to-one task is sentiment classification in NLP, where we are given a sequence of words of a sentence and then classify what sentiment (e.g. positive or negative) that sentence carries.
- An example of a many-to-many task is video captioning, where the input is a sequence of video frames and the output is a caption describing what was in the video, as shown in the fourth model in Figure 1. Another example of a many-to-many task is machine translation in NLP, where an RNN takes a sequence of words of a sentence in English and is asked to produce a sequence of words of a sentence in French.
- There is also a variation of the many-to-many task, shown in the last model in Figure 1, where the model generates an output at every timestep. An example of this many-to-many task is frame-level video classification, where the model classifies every single frame of the video into some number of classes. Note that we don't want this prediction to be a function of only the current timestep (the current frame of the video), but also of all the timesteps (frames) that have come before it.
In general, RNNs allow us to wire up an architecture where the prediction at every single timestep is a function of all the timesteps that have come before.
Existing convnets are insufficient for tasks whose inputs and outputs have variable sequence lengths. In the example of video captioning, inputs have a variable number of frames (e.g. a 10-minute vs. a 10-hour-long video) and outputs are captions of variable length. Convnets can only take inputs of a fixed width and height and cannot generalize over inputs of different sizes. To tackle this problem, we introduce Recurrent Neural Networks (RNNs).
An RNN is basically a black box (left of Figure 2) that has an "internal state" which is updated as a sequence is processed. At every single timestep, we feed an input vector into the RNN, and the RNN modifies that state as a function of what it receives. When we tune the RNN weights, the RNN will show different behaviors in terms of how its state evolves as it receives these inputs. We are also interested in producing an output based on the RNN state, so we can produce output vectors on top of the RNN (as depicted in Figure 2).
If we unroll an RNN model (right of Figure 2), then there are inputs (e.g. video frames) at different timesteps shown as $$x_1, x_2, x_3, \ldots, x_t$$. At each timestep the RNN takes in the current input and its previous internal state, updates the state, and can produce an output $$y_t$$ on top of it. All the unrolled RNN blocks in Figure 2 (right) are the same block with the same weights, applied to a different input at each timestep.

More precisely, the RNN can be represented as a recurrence formula of some function $$f_W$$ with parameters $$W$$:

$$
h_t = f_W(h_{t-1}, x_t)
$$

where at every timestep it receives the previous state as a vector $$h_{t-1}$$ and the current input vector $$x_t$$, and produces the current state as a vector $$h_t$$. The same function $$f_W$$ with the same weights $$W$$ is applied at every single timestep, which allows us to use the RNN on sequences of arbitrary length.

In the simplest form of RNN, which we call a Vanilla RNN, the network has just a single hidden state $$h$$. The recurrence uses two weight matrices, $$W_{hh}$$ and $$W_{xh}$$, which project the previous hidden state $$h_{t-1}$$ and the current input $$x_t$$ respectively; the two projections are summed and squashed with a $$tanh$$ to produce the new hidden state $$h_t$$:

$$
h_t = tanh(W_{hh}h_{t-1} + W_{xh}x_t)
$$

We can base predictions on top of $$h_t$$ with another matrix projection on top of the hidden state:

$$
y_t = W_{hy}h_t
$$

So far we have described the RNN in terms of abstract vectors $$x$$, $$h$$, and $$y$$; next, we give these vectors concrete semantics with a character-level language model example.
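To make the recurrence concrete, here is a minimal NumPy sketch of a single Vanilla RNN timestep. The function name `rnn_step`, the dimensions, and the random initialization are illustrative assumptions, not part of the original notes:

```python
import numpy as np

def rnn_step(x, h_prev, W_xh, W_hh, W_hy):
    """One Vanilla RNN timestep: update the hidden state and produce an output.

    x:      input vector at this timestep, shape (input_dim,)
    h_prev: previous hidden state, shape (hidden_dim,)
    """
    # h_t = tanh(W_hh h_{t-1} + W_xh x_t)
    h = np.tanh(W_hh @ h_prev + W_xh @ x)
    # y_t = W_hy h_t  (prediction built on top of the hidden state)
    y = W_hy @ h
    return h, y

# Example usage with small random weights (dimensions are arbitrary here).
input_dim, hidden_dim, output_dim = 4, 3, 4
rng = np.random.default_rng(0)
W_xh = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
W_hy = rng.normal(scale=0.1, size=(output_dim, hidden_dim))

h = np.zeros(hidden_dim)        # start with an all-zero hidden state
for x in np.eye(input_dim):     # feed a toy sequence of one-hot vectors
    h, y = rnn_step(x, h, W_xh, W_hh, W_hy)
```

Note that the same `W_hh`, `W_xh`, and `W_hy` matrices are reused at every timestep, which is what lets the same network handle sequences of any length.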
One of the simplest ways in which we can use an RNN is the character-level language model, since it's intuitive to understand. The way this RNN works is that we feed a sequence of characters into the RNN and, at every single timestep, ask the RNN to predict the next character in the sequence. The RNN's prediction takes the form of a score distribution over the characters in the vocabulary for what it thinks should come next in the sequence it has seen so far.
So suppose, in a very simple example (Figure 3), we have a training sequence of just one string, "hello", with a vocabulary of four characters {"h", "e", "l", "o"}.

As shown in Figure 3, we'll feed one character at a time into the RNN: first "h", then "e", then "l", and then "l".
Then we're going to use the recurrence formula from the previous section at every single timestep.
Suppose we start off with the hidden state initialized to all zeros. By applying the recurrence formula, at each timestep we obtain a hidden state vector that summarizes all the characters the RNN has seen so far.
As we apply this recurrence at every timestep, we're going to predict what the next character in the sequence should be. Since we have four characters in the vocabulary, the RNN outputs a 4-dimensional vector of scores (one per character) at every single timestep.

As shown in Figure 3, in the very first timestep we fed in the character "h", and the RNN, with its current setting of weights, computed a vector of scores for which character it thinks should come next. Since in the training sequence the character following "h" is "e", we want the score assigned to "e" to be high and the scores of all the other characters to be low. The same applies at every timestep, and backpropagating this loss through time is how the RNN's weights are trained.
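As a companion to Figure 3, the following is a small sketch of the forward pass and loss of this character-level model on the "hello" example, assuming one-hot character encodings and a softmax cross-entropy loss on the correct next character; the weight sizes and initialization are arbitrary illustrations:

```python
import numpy as np

# Toy character-level language model setup for the training string "hello".
vocab = ['h', 'e', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

inputs  = ['h', 'e', 'l', 'l']   # characters fed into the RNN
targets = ['e', 'l', 'l', 'o']   # characters the RNN should predict next

hidden_dim, vocab_size = 3, len(vocab)
rng = np.random.default_rng(0)
W_xh = rng.normal(scale=0.1, size=(hidden_dim, vocab_size))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
W_hy = rng.normal(scale=0.1, size=(vocab_size, hidden_dim))

h = np.zeros(hidden_dim)
loss = 0.0
for ch_in, ch_target in zip(inputs, targets):
    x = np.zeros(vocab_size)
    x[char_to_ix[ch_in]] = 1.0                       # one-hot encode the input character
    h = np.tanh(W_hh @ h + W_xh @ x)                 # update hidden state
    scores = W_hy @ h                                # unnormalized scores over the vocabulary
    probs = np.exp(scores) / np.sum(np.exp(scores))  # softmax
    loss += -np.log(probs[char_to_ix[ch_target]])    # cross-entropy on the correct next character
print(loss)
```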
So far we have only shown RNNs with just one layer. However, we're not limited to single-layer architectures. RNNs are often used today in a more complex manner: they can be stacked together in multiple layers, which gives more depth, and empirically deeper architectures tend to work better (Figure 4).
For example, in Figure 4, there are three separate RNNs, each with its own set of weights. The three RNNs are stacked on top of each other, so the input of the second RNN (second RNN layer in Figure 4) is the hidden state vector of the first RNN (first RNN layer in Figure 4). All stacked RNNs are trained jointly, and the diagram in Figure 4 represents one computational graph.
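Below is a minimal sketch of how such stacking can be wired up: the hidden state produced by one layer at a timestep serves as the input to the layer above it at the same timestep. The helper name `stacked_rnn_step` and the dimensions are illustrative assumptions:

```python
import numpy as np

def stacked_rnn_step(x, h_prev, weights):
    """One timestep of a stack of Vanilla RNN layers.

    h_prev:  list of hidden states, one per layer
    weights: list of (W_xh, W_hh) pairs, one per layer; layer l's "input"
             is the hidden state produced by layer l-1 at the same timestep.
    """
    h_new = []
    inp = x
    for (W_xh, W_hh), h_l in zip(weights, h_prev):
        h_l = np.tanh(W_hh @ h_l + W_xh @ inp)  # same vanilla recurrence in every layer
        h_new.append(h_l)
        inp = h_l                               # feed this layer's state upward
    return h_new

# Three stacked layers with their own weights (dimensions are illustrative).
dims = [4, 3, 3, 3]  # input_dim followed by the hidden_dim of each layer
rng = np.random.default_rng(0)
weights = [(rng.normal(scale=0.1, size=(dims[l + 1], dims[l])),
            rng.normal(scale=0.1, size=(dims[l + 1], dims[l + 1]))) for l in range(3)]
h = [np.zeros(dims[l + 1]) for l in range(3)]
h = stacked_rnn_step(np.ones(dims[0]), h, weights)
```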
So far we have seen only a simple recurrence formula for the Vanilla RNN. In practice, we will rarely ever use the Vanilla RNN formula. Instead, we will use what we call a Long Short-Term Memory (LSTM) RNN.
An RNN block takes in the input $$x_t$$ and the previous hidden representation $$h_{t-1}$$, and learns a transformation which is passed through tanh to produce the hidden representation $$h_t$$ for the next timestep:

$$
h_t = tanh(W_{hh}h_{t-1} + W_{xh}x_t)
$$
For backpropagation, let's examine how the output at the very last timestep affects the weights at the very first timestep.
The partial derivative of $$h_t$$ with respect to $$h_{t-1}$$ is:

$$
\frac{\partial h_{t}}{\partial h_{t-1}} = tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)W_{hh}
$$
We update the weights $$W_{hh}$$ by computing the derivative of the loss at the very last timestep, $$L_{t}$$, with respect to $$W_{hh}$$. By the chain rule, this derivative is a product of the per-timestep factors above:

$$
\begin{aligned}
\frac{\partial L_{t}}{\partial W} = \frac{\partial L_{t}}{\partial h_{t}}\left(\prod_{t=2}^{T} \frac{\partial h_{t}}{\partial h_{t-1}}\right)\frac{\partial h_{1}}{\partial W} = \frac{\partial L_{t}}{\partial h_{t}}\left(\prod_{t=2}^{T} tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)\, W_{hh}\right)\frac{\partial h_{1}}{\partial W}
\end{aligned}
$$
- **Vanishing gradient:** We see that $$tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)$$ will almost always be less than 1, because tanh is always between negative one and one. Thus, as $$t$$ gets larger (i.e. longer timesteps), the gradient $$\frac{\partial L_{t}}{\partial W}$$ will decrease in value and get close to zero. This leads to the vanishing gradient problem, where gradients at future timesteps rarely impact gradients at the very first timestep. This is problematic when we model long sequences of inputs, because the updates will be extremely slow.
- **Removing the non-linearity (tanh):** If we remove the non-linearity (tanh) to solve the vanishing gradient problem, then we are left with

  $$
  \begin{aligned}
  \frac{\partial L_{t}}{\partial W} = \frac{\partial L_{t}}{\partial h_{t}}\left(\prod_{t=2}^{T} W_{hh}\right)\frac{\partial h_{1}}{\partial W} = \frac{\partial L_{t}}{\partial h_{t}}\, W_{hh}^{T-1}\, \frac{\partial h_{1}}{\partial W}
  \end{aligned}
  $$

  - **Exploding gradients:** If the largest singular value of $$W_{hh}$$ is greater than 1, the gradients will blow up and the model will receive very large gradients coming back from future timesteps. Exploding gradients often lead to gradients that are NaNs.
  - **Vanishing gradients:** If the largest singular value of $$W_{hh}$$ is smaller than 1, then we will have the vanishing gradient problem mentioned above, which will significantly slow down learning.
In practice, we can treat the exploding gradient problem with gradient clipping, which clips large gradient values to a maximum threshold. However, since the vanishing gradient problem still exists in cases where the largest singular value of the $$W_{hh}$$ matrix is less than one, the LSTM was designed to avoid this problem.
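The following toy sketch illustrates both regimes, using a diagonal $$W_{hh}$$ whose largest singular value is either below or above 1, together with a simple norm-based gradient clipping helper; the names and thresholds here are illustrative assumptions:

```python
import numpy as np

hidden_dim, T = 3, 50

# Repeatedly multiplying by W_hh (as in the product above) shrinks or blows up
# the backpropagated signal depending on W_hh's largest singular value.
for scale in [0.5, 1.5]:
    W_hh = scale * np.eye(hidden_dim)      # largest singular value = scale
    grad = np.ones(hidden_dim)
    for _ in range(T - 1):
        grad = W_hh.T @ grad
    print(scale, np.linalg.norm(grad))     # ~0 for 0.5, astronomically large for 1.5

def clip_gradient(grad, max_norm=5.0):
    """Gradient clipping: rescale the gradient if its norm exceeds a threshold."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        grad = grad * (max_norm / norm)
    return grad
```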
The following is the precise formulation for the LSTM. On step $$t$$, there is a hidden state $$h_t$$ and a cell state $$c_t$$, both vectors of the same size. The distinguishing feature of the LSTM compared with the Vanilla RNN is this additional cell state $$c_t$$: intuitively, $$c_t$$ stores long-term information, and the LSTM can read, erase, and write information to and from it. The LSTM alters the cell state through gates: the input gate $$i$$, the forget gate $$f$$, the output gate $$o$$, and the "gate" gate $$g$$. The values of $$i, f, o$$ range from closed (0) to open (1).

At every timestep we have an input vector $$x_t$$, the previous hidden state $$h_{t-1}$$, and the previous cell state $$c_{t-1}$$; the LSTM computes the next hidden state $$h_t$$ and the next cell state $$c_t$$ as:

$$
\begin{aligned}
f_t &= \sigma(W_{hf}h_{t-1} + W_{xf}x_t)\\
i_t &= \sigma(W_{hi}h_{t-1} + W_{xi}x_t)\\
o_t &= \sigma(W_{ho}h_{t-1} + W_{xo}x_t)\\
g_t &= tanh(W_{hg}h_{t-1} + W_{xg}x_t)\\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t\\
h_t &= o_t \odot tanh(c_t)
\end{aligned}
$$

where $$\sigma$$ is the sigmoid function and $$\odot$$ denotes element-wise multiplication.

Since all of $$f_t, i_t, o_t$$ are squashed by the sigmoid into values between zero and one, multiplying by them element-wise controls how much information is removed from, added to, or exposed from the cell and hidden states.
- **Forget gate:** The forget gate $$f_t$$ at timestep $$t$$ controls how much information needs to be "removed" from the previous cell state $$c_{t-1}$$. This gate learns to erase hidden representations from previous timesteps, which is why the LSTM keeps two representations: the hidden state $$h_t$$ and the cell state $$c_t$$. The cell state $$c_t$$ gets propagated over time, and the forget gate decides whether to forget the previous cell state or not.
- **Input gate:** The input gate $$i_t$$ at timestep $$t$$ controls how much information needs to be "added" to the next cell state $$c_t$$ from the previous hidden state $$h_{t-1}$$ and the input $$x_t$$. Instead of tanh, the input gate uses a sigmoid function, which converts its inputs to values between zero and one. This serves as a switch whose values are almost always close to zero or close to one: the candidate update produced by the "gate" gate $$g_t$$ is multiplied element-wise by the input gate $$i_t$$ before being written into the cell state.
- **Output gate:** The output gate $$o_t$$ at timestep $$t$$ controls how much information needs to be "shown" as output in the current hidden state $$h_t$$.
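Putting the gate equations together, here is a minimal NumPy sketch of a single LSTM timestep; the dictionary-based weight layout and the omission of bias terms are simplifications for illustration, not a canonical implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W_h, W_x):
    """One LSTM timestep following the gate equations above.

    W_h: dict of hidden-to-gate matrices keyed by 'f', 'i', 'o', 'g'
    W_x: dict of input-to-gate matrices keyed by 'f', 'i', 'o', 'g'
    (Bias terms are omitted to keep the sketch close to the equations.)
    """
    f = sigmoid(W_h['f'] @ h_prev + W_x['f'] @ x)   # forget gate
    i = sigmoid(W_h['i'] @ h_prev + W_x['i'] @ x)   # input gate
    o = sigmoid(W_h['o'] @ h_prev + W_x['o'] @ x)   # output gate
    g = np.tanh(W_h['g'] @ h_prev + W_x['g'] @ x)   # "gate" gate (candidate update)

    c = f * c_prev + i * g    # erase part of the old cell state, write new content
    h = o * np.tanh(c)        # expose part of the cell state as the hidden state
    return h, c
```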
The key idea of the LSTM is the cell state, the horizontal line running between recurrent timesteps. You can imagine the cell state as a kind of highway of information passing straight down the entire chain, with only some minor linear interactions. With the formulation above, it's easy for information to just flow along this highway (Figure 5). Thus, even when a bunch of LSTMs are stacked together, we can get an uninterrupted gradient flow where the gradients flow back through cell states instead of hidden states.
This greatly mitigates the gradient vanishing/exploding problem we outlined above. Figure 5 also shows that the gradient contains a vector of activations of the forget gate, which allows better control of the gradient values through suitable parameter updates of the forget gate.
The LSTM architecture makes it easier for the RNN to preserve information over many recurrent timesteps. For example, if the forget gate is set to 1 and the input gate is set to 0, then the information in the cell state will always be preserved over many recurrent timesteps, as the short sketch below illustrates. For a Vanilla RNN, in contrast, it's much harder to preserve information in the hidden state across recurrent timesteps using just a single weight matrix.
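As a tiny sanity check of this claim, the snippet below applies the cell state update with the forget gate fixed at 1 and the input gate fixed at 0; the specific numbers are arbitrary:

```python
import numpy as np

# With the forget gate at 1 and the input gate at 0, the update
# c_t = f * c_{t-1} + i * g leaves c_t exactly equal to c_{t-1},
# so information is carried across many timesteps unchanged.
c = np.array([0.7, -1.2, 0.3])
f, i = np.ones_like(c), np.zeros_like(c)
g = np.tanh(np.random.default_rng(0).normal(size=c.shape))  # candidate update (ignored since i = 0)
for _ in range(100):
    c = f * c + i * g
print(c)   # still [0.7, -1.2, 0.3]
```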
LSTMs do not guarantee that there are no vanishing/exploding gradient problems, but they do provide an easier way for the model to learn long-distance dependencies.