A recurrent neural network (RNN) is a deep learning network structure that uses information from the past to improve performance on current and future inputs. What makes RNNs unique is that the network contains a hidden state and loops: the looping structure allows the network to store past information in the hidden state and operate on sequences.
These features of recurrent neural networks make them well suited for solving a variety of problems with sequential data of varying length such as:
- Natural language processing
- Signal classification
- Video analysis
How does the RNN know how to apply past information to the current input? The network has two sets of weights: one for the hidden state vector and one for the inputs. During training, the network learns both sets of weights. At inference time, the output depends on the current input as well as the hidden state, which encodes the previous inputs.
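The update described above can be written as a single recurrence: the new hidden state is a nonlinear function of the current input (through the input weights) and the previous hidden state (through the recurrent weights). Below is a minimal NumPy sketch of that recurrence; the names `Wx`, `Wh`, and `b` and all dimensions are illustrative, not taken from any particular library.

```python
import numpy as np

def rnn_step(x, h, Wx, Wh, b):
    """One recurrent update: combine the current input with the
    previous hidden state using two separate weight matrices."""
    return np.tanh(Wx @ x + Wh @ h + b)

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
Wx = rng.standard_normal((hidden_size, input_size)) * 0.1   # input weights
Wh = rng.standard_normal((hidden_size, hidden_size)) * 0.1  # hidden-state weights
b = np.zeros(hidden_size)

h = np.zeros(hidden_size)                 # initial hidden state
sequence = rng.standard_normal((5, input_size))
for x in sequence:                        # the loop carries h forward in time
    h = rnn_step(x, h, Wx, Wh, b)
```

Because the same `Wx` and `Wh` are reused at every step, the network can process sequences of any length with a fixed number of parameters.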
In practice, simple RNNs struggle to learn longer-term dependencies. RNNs are commonly trained through backpropagation through time, where they can suffer from either a ‘vanishing’ or ‘exploding’ gradient problem. These problems cause the gradients flowing through the network to become either very small or very large, limiting the network’s ability to learn long-term relationships.
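The cause is easy to see in a toy scalar version of the recurrence: during backpropagation through time, the gradient at a given step is scaled (roughly) by the recurrent weight once for every step between it and the output, so the scaling compounds exponentially with sequence length. The numbers below are purely illustrative.

```python
# Toy demonstration of vanishing/exploding gradients: in a scalar
# recurrence, the gradient flowing back over `steps` time steps is
# multiplied by the recurrent weight `w` once per step.
steps = 50
for w in (0.5, 1.5):  # recurrent weight magnitudes below / above 1
    grad_scale = w ** steps
    print(f"w={w}: gradient scaled by {grad_scale:.3g} after {steps} steps")
```

With `w = 0.5` the factor is around 1e-15 (vanishing); with `w = 1.5` it exceeds 1e8 (exploding). Either way, the early steps of the sequence contribute almost nothing useful to the weight update.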
A special type of recurrent neural network that overcomes this issue is the long short-term memory (LSTM) network. LSTM networks use additional gates to control what information in the cell state makes it to the output and the next hidden state. This allows the network to more effectively learn long-term relationships in the data. LSTMs are a commonly implemented type of RNN.
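The gating described above can be sketched with the standard LSTM equations: a forget gate, input gate, and output gate (each a sigmoid) control what is erased from, written to, and read from the cell state. This NumPy sketch uses the conventional notation; the stacked parameter layout and all sizes are illustrative choices, not any library's API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, W, U, b):
    """One LSTM update. W, U, b hold the parameters of all four gates
    stacked together; gate names follow the standard notation."""
    n = h.size
    z = W @ x + U @ h + b
    f = sigmoid(z[0:n])        # forget gate: what to erase from the cell
    i = sigmoid(z[n:2*n])      # input gate: what new information to store
    o = sigmoid(z[2*n:3*n])    # output gate: what the cell exposes
    g = np.tanh(z[3*n:4*n])    # candidate cell update
    c = f * c + i * g          # new cell state (long-term memory)
    h = o * np.tanh(c)         # new hidden state (the layer's output)
    return h, c

rng = np.random.default_rng(1)
nx, nh = 3, 4
W = rng.standard_normal((4 * nh, nx)) * 0.1
U = rng.standard_normal((4 * nh, nh)) * 0.1
b = np.zeros(4 * nh)
h, c = np.zeros(nh), np.zeros(nh)
for x in rng.standard_normal((5, nx)):
    h, c = lstm_step(x, h, c, W, U, b)
```

The key difference from the simple RNN is the additive cell update `c = f * c + i * g`: gradients can flow through the cell state without being repeatedly squashed, which is what lets LSTMs retain information over long sequences.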
MATLAB® has a full set of features and functionality to train and implement LSTM networks with text, image, signal, and time series data. The next sections will explore the applications of RNNs and some examples using MATLAB.
Applications of RNNs
Natural Language Processing
Language is naturally sequential, and pieces of text vary in length. This makes RNNs a great tool to solve problems in this area because they can learn to contextualize words in a sentence. One example is sentiment analysis, a method for categorizing the emotional tone of words and phrases. Machine translation, the use of an algorithm to translate between languages, is another common application. Words first need to be converted from text data into numeric sequences. An effective way of doing this is a word embedding layer. Word embeddings map words into numeric vectors. The example below uses word embeddings to train a word sentiment classifier, displaying the results with the MATLAB wordcloud function.
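At its core, a word embedding is just a trainable lookup table: each word's index selects a row of a matrix, turning text into the numeric sequence an RNN can consume. A minimal NumPy sketch, with a toy three-word vocabulary and random (untrained) vectors purely for illustration:

```python
import numpy as np

# A word embedding is a lookup table: each word index selects a row
# of a (normally trainable) matrix, turning text into numbers.
vocab = {"good": 0, "bad": 1, "movie": 2}         # toy vocabulary
rng = np.random.default_rng(2)
embedding = rng.standard_normal((len(vocab), 5))  # 5-dimensional vectors

sentence = ["good", "movie"]
indices = [vocab[w] for w in sentence]
vectors = embedding[indices]   # shape (2, 5): one vector per word, in order
```

During training, the rows of the embedding matrix are updated along with the rest of the network, so words used in similar contexts end up with similar vectors.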
In another classifier example, MATLAB uses RNNs to classify text data to determine the type of manufacturing failure. MATLAB is also used in a machine translation example to train a network to understand Roman numerals.
Signal Classification
Signals are another example of naturally sequential data, as they are often collected from sensors over time. It is useful to automatically classify signals, as this can decrease the manual time needed for large datasets or allow classification in real time. Raw signal data can be fed into deep networks or pre-processed to focus on other features such as frequency components. Feature extraction can greatly improve network performance, as in an example with electrical heart signals. Below is an example using raw signal data in an RNN.
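To make the pre-processing idea concrete, here is a small NumPy sketch of frequency-domain feature extraction with an FFT. The signal, sampling rate, and frequencies are synthetic stand-ins; a real pipeline would use recorded sensor data.

```python
import numpy as np

# Synthetic signal: a 50 Hz tone plus a weaker 120 Hz tone, sampled
# at 1000 Hz for one second (all values are illustrative).
fs = 1000
t = np.arange(0, 1, 1 / fs)
signal = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)

# Frequency-domain features: magnitude spectrum of the real signal.
spectrum = np.abs(np.fft.rfft(signal))
freqs = np.fft.rfftfreq(signal.size, 1 / fs)
peak = freqs[np.argmax(spectrum)]   # dominant frequency component
```

Features like the spectrum (or the location of its peaks) give the network a much more compact and informative view of the signal than the raw samples.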
Video Analysis
RNNs work well for videos because videos are essentially a sequence of images. Similar to working with signals, it helps to do feature extraction before feeding the sequence into the RNN. In this example, a pretrained GoogLeNet model (a convolutional neural network) is used for feature extraction on each frame. You can see the network architecture below.
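The shape of that pipeline can be sketched as: run a feature extractor on each frame, then stack the resulting vectors into the sequence the RNN consumes. In this NumPy sketch, a trivial statistics function stands in for the pretrained CNN; the frame sizes and feature dimensions are illustrative.

```python
import numpy as np

def extract_features(frame):
    """Placeholder for a CNN forward pass (e.g., a pretrained model
    such as GoogLeNet); returns a fixed-size feature vector."""
    return np.array([frame.mean(), frame.std(), frame.max(), frame.min()])

rng = np.random.default_rng(3)
video = rng.random((16, 224, 224, 3))   # 16 frames of 224x224 RGB (synthetic)

# One feature vector per frame, stacked in time order: this (frames x
# features) sequence is what the RNN would consume.
sequence = np.stack([extract_features(f) for f in video])
```

Separating the two stages this way lets the CNN handle the spatial structure within each frame while the RNN handles the temporal structure across frames.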