Classify Text Data Using Deep Learning

This example shows how to classify text data using a deep learning long short-term memory (LSTM) network.

Text data is naturally sequential. A piece of text is a sequence of words, which might have dependencies between them. To learn and use long-term dependencies to classify sequence data, use an LSTM neural network. An LSTM network is a type of recurrent neural network (RNN) that can learn long-term dependencies between time steps of sequence data.

To input text to an LSTM network, first convert the text data into numeric sequences. You can achieve this using a word encoding, which maps documents to sequences of numeric indices. For better results, also include a word embedding layer in the network. Word embeddings map words in a vocabulary to numeric vectors rather than scalar indices. These embeddings capture semantic details of the words, so that words with similar meanings have similar vectors. They also model relationships between words through vector arithmetic. For example, the relationship "Rome is to Italy as Paris is to France" is described by the equation Italy - Rome + Paris = France.
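The vector arithmetic can be illustrated with a toy sketch. The two-dimensional vectors below are hand-made for illustration only; they are not learned embeddings, and real embeddings satisfy such relationships only approximately.

```matlab
% Toy, hand-made 2-D "embeddings" (illustrative values, not learned).
% Dimension 1 encodes city (0) versus country (1).
% Dimension 2 encodes the nation (1 for Italy, 2 for France).
rome   = [0 1];
italy  = [1 1];
paris  = [0 2];
france = [1 2];

% The offset italy - rome is the "capital to country" direction.
% Adding it to paris lands on france.
result = italy - rome + paris;
isFrance = isequal(result,france);
```

With learned embeddings, the result is usually compared to candidate word vectors by distance or cosine similarity rather than by exact equality.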

There are four steps in training and using the LSTM network in this example:

  • Import and preprocess the data.

  • Convert the words to numeric sequences using a word encoding.

  • Create and train an LSTM network with a word embedding layer.

  • Classify new text data using the trained LSTM network.

Import Data

Import the factory reports data. This data contains labeled textual descriptions of factory events. To import the text data as strings, specify the text type to be 'string'.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

The goal of this example is to classify events by the label in the Category column. To divide the data into classes, convert these labels to categorical.

data.Category = categorical(data.Category);

View the distribution of the classes in the data using a histogram.

histogram(data.Category)
title("Class Distribution")

The next step is to partition the data into a training partition and a held-out partition for validation. Specify the holdout percentage to be 20%.

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

Extract the text data and labels from the partitioned tables.

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;
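As a quick sanity check on what a holdout partition does, the following sketch partitions a small synthetic label vector. The labels and counts are made up for illustration and are not taken from the factory reports data. The sketch assumes Statistics and Machine Learning Toolbox, which provides cvpartition.

```matlab
% Synthetic labels for illustration only: 80 of class "A", 20 of class "B".
labelsDemo = categorical([repmat("A",80,1); repmat("B",20,1)]);

% Hold out 20% of the observations.
cvpDemo = cvpartition(labelsDemo,'Holdout',0.2);

% training and test return logical index vectors. Together they
% cover every observation exactly once.
idxTrain = training(cvpDemo);
idxTest = test(cvpDemo);
numTrain = sum(idxTrain);
numTest = sum(idxTest);
```

Because the partition is random, the selected rows differ from run to run unless you fix the random seed first, for example with rng.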

To check that you have imported the data correctly, visualize the training text data using a word cloud.

wordcloud(textDataTrain);
title("Training Data")

Preprocess Text Data

Create a function that tokenizes and preprocesses the text data. The function preprocessText, listed at the end of the example, performs these steps:

  1. Tokenize the text using tokenizedDocument.

  2. Convert the text to lowercase using lower.

  3. Erase the punctuation using erasePunctuation.

Preprocess the training data and the validation data using the preprocessText function.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

View the first few preprocessed training documents.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
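To see the effect of each preprocessing step on a single report, the following sketch applies the same three functions inline to one string. The variable names are chosen for illustration.

```matlab
% Apply the three preprocessing steps to one report.
str = "Mixer tripped the fuses.";
doc = tokenizedDocument(str);   % split the text into tokens
doc = lower(doc);               % convert the tokens to lowercase
doc = erasePunctuation(doc);    % remove punctuation

% Convert the scalar tokenizedDocument to a string array of its words.
words = string(doc);
```

The period is removed, so the document is left with the four tokens shown in the last row of the output above.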

Convert Document to Sequences

To input the documents into an LSTM network, use a word encoding to convert the documents into sequences of numeric indices.

To create a word encoding, use the wordEncoding function.

enc = wordEncoding(documentsTrain);
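A word encoding is a two-way mapping between words and indices. The following sketch builds a small encoding from two made-up documents (not the training data) and converts words to indices and back using word2ind and ind2word.

```matlab
% Build a small word encoding from two illustrative documents.
docsDemo = tokenizedDocument([ ...
    "mixer tripped the fuses"
    "a fuse is blown in the mixer"]);
encDemo = wordEncoding(docsDemo);

% Map words to their numeric indices, then map the indices back.
idx = word2ind(encDemo,["mixer" "fuses"]);
wordsBack = ind2word(encDemo,idx);
```

The same calls work with the enc object created from the training documents.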

The next conversion step is to pad and truncate documents so they are all the same length. The trainingOptions function provides options to pad and truncate input sequences automatically. However, these options are not well suited for sequences of word vectors. Instead, pad and truncate the sequences manually. If you left-pad and truncate the sequences of word vectors, then the training might improve.

To pad and truncate the documents, first choose a target length, and then truncate documents that are longer than it and left-pad documents that are shorter than it. For best results, the target length should be short without discarding large amounts of data. To find a suitable target length, view a histogram of the training document lengths.

documentLengths = doclength(documentsTrain);
histogram(documentLengths)
title("Document Lengths")
ylabel("Number of Documents")

Most of the training documents have fewer than 10 tokens. Use 10 as the target length for truncation and padding.

Convert the documents to sequences of numeric indices using doc2sequence. To truncate or left-pad the sequences to have length 10, set the 'Length' option to 10.

sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
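To make the padding and truncation concrete, the following sketch applies the same idea to plain index vectors. The helper padOrTruncate is hypothetical and is not part of any toolbox. It assumes a padding value of 0, padding on the left, and truncation on the right.

```matlab
% Hypothetical helper: left-pad with zeros or truncate on the right
% so that the result has exactly len elements.
padOrTruncate = @(seq,len) ...
    [zeros(1,max(len-numel(seq),0)) seq(1:min(numel(seq),len))];

% A short sequence gains leading zeros.
shortSeq = padOrTruncate([4 8 15],5);

% A long sequence loses its trailing elements.
longSeq = padOrTruncate(1:7,5);
```

Here shortSeq is [0 0 4 8 15] and longSeq is [1 2 3 4 5]. Left-padding keeps the real tokens at the end of the sequence, closest to the final time step that the LSTM layer outputs in 'last' mode.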

Convert the validation documents to sequences using the same options.

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

Create and Train LSTM Network

Define the LSTM network architecture. To input sequence data into the network, include a sequence input layer and set the input size to 1. Next, include a word embedding layer of dimension 50 and the same number of words as the word encoding. Next, include an LSTM layer and set the number of hidden units to 80. To use the LSTM layer for a sequence-to-label classification problem, set the output mode to 'last'. Finally, add a fully connected layer with the same size as the number of classes, a softmax layer, and a classification layer.

inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;

numWords = enc.NumWords;
numClasses = numel(categories(YTrain));

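layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]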
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                    LSTM with 80 hidden units
     4   ''   Fully Connected         4 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

Specify Training Options

Specify the training options:

  • Train using the Adam solver.

  • Specify a mini-batch size of 16.

  • Specify a gradient threshold of 2.

  • Shuffle the data every epoch.

  • Monitor the training progress by setting the 'Plots' option to 'training-progress'.

  • Specify the validation data using the 'ValidationData' option.

  • Suppress verbose output by setting the 'Verbose' option to false.

By default, trainNetwork uses a GPU if one is available. Otherwise, it uses the CPU. To specify the execution environment manually, use the 'ExecutionEnvironment' name-value pair argument of trainingOptions. Training on a CPU can take significantly longer than training on a GPU. Training with a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).

options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

Train the LSTM network using the trainNetwork function.

net = trainNetwork(XTrain,YTrain,layers,options);
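After training, a common check is the validation accuracy, which is the fraction of predicted labels that match the true labels. The following sketch uses made-up label vectors so that it runs on its own. With the variables in this example, the predictions would instead come from classify(net,XValidation) and the reference labels would be YValidation.

```matlab
% Made-up true and predicted labels, for illustration only.
YTrueDemo = categorical(["Leak";"Leak";"Mechanical Failure";"Electronic Failure"]);
YPredDemo = categorical(["Leak";"Mechanical Failure";"Mechanical Failure";"Electronic Failure"]);

% Accuracy is the mean of the elementwise matches.
accuracy = mean(YPredDemo == YTrueDemo);
```

Three of the four made-up predictions match, so accuracy is 0.75.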

Predict Using New Data

Classify the event type of three new reports. Create a string array containing the new reports.

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Preprocess the text data using the same preprocessing steps as the training documents.

documentsNew = preprocessText(reportsNew);

Convert the text data to sequences using doc2sequence with the same options as when creating the training sequences.

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

Classify the new sequences using the trained LSTM network.

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Preprocessing Function

The function preprocessText performs these steps:

  1. Tokenize the text using tokenizedDocument.

  2. Convert the text to lowercase using lower.

  3. Erase the punctuation using erasePunctuation.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

