This example shows how to compare latent Dirichlet allocation (LDA) solvers by comparing the goodness of fit and the time taken to fit the model.
To reproduce the results of this example, set rng to 'default'.
rng('default')
Load the example data. The file weatherReports.csv contains weather reports, including a text description and categorical labels for each event. Extract the text data from the field event_narrative.
filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
textData = data.event_narrative;
Set aside 10% of the documents at random for validation.
numDocuments = numel(textData);
cvp = cvpartition(numDocuments,'HoldOut',0.1);
textDataTrain = textData(training(cvp));
textDataValidation = textData(test(cvp));
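The partition cvpartition(numDocuments,'HoldOut',0.1) assigns about 10% of the documents to the test (validation) set. To make that concrete, the following sketch performs a comparable split by hand, using randperm in place of cvpartition. The toy textData values are hypothetical. The sketch computes the holdout size with round, and this rounding may not match cvpartition exactly.

```matlab
% Hypothetical stand-in for the weather report narratives.
textData = [
    "Heavy rain flooded the road."
    "Strong wind downed trees."
    "Hail damaged several cars."
    "Snow closed the highway."
    "A tornado touched down briefly."
    "Lightning struck a barn."
    "Flash flooding closed a bridge."
    "Freezing rain coated power lines."
    "Dense fog reduced visibility."
    "Thunderstorm winds damaged a roof."];

rng('default')
numDocuments = numel(textData);

% Hold out 10% of the documents (rounded) for validation.
numValidation = round(0.1*numDocuments);
perm = randperm(numDocuments);
isValidation = false(numDocuments,1);
isValidation(perm(1:numValidation)) = true;

% Every document lands in exactly one of the two sets.
textDataTrain = textData(~isValidation);
textDataValidation = textData(isValidation);
```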
Tokenize and preprocess the text data using the function preprocessText, which is listed at the end of this example.
documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);
Create a bag-of-words model from the training documents. Remove the words that do not appear more than two times in total. Remove any documents containing no words.
bag = bagOfWords(documentsTrain);
bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag);
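The following sketch illustrates what removeInfrequentWords(bag,2) followed by removeEmptyDocuments does. It works on a plain count matrix with hypothetical counts and a hypothetical vocabulary, and it is not the toolbox implementation. A word survives only if its total count across all documents is greater than 2. A document that is left with no words is then dropped.

```matlab
% Hypothetical word-count matrix: rows are documents, columns are words.
counts = [1 0 3 0
          0 1 2 0
          1 0 0 0];
vocabulary = ["rain" "snow" "wind" "hail"];

% Total count of each word across all documents.
totals = sum(counts,1);

% Keep only words that appear more than two times in total.
keep = totals > 2;
counts = counts(:,keep);
vocabulary = vocabulary(keep);

% Remove documents that no longer contain any words.
counts = counts(sum(counts,2) > 0,:);
```

Here only "wind" (total count 5) survives. The third document then has no words left, so it is removed.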
For each of the LDA solvers, fit an LDA model with 60 topics. To distinguish the solvers when plotting the results on the same axes, specify different line properties for each solver.
numTopics = 60;
solvers = ["cgs" "avb" "cvb0" "savb"];
lineSpecs = ["+-" "*-" "x-" "o-"];
For the validation data, create a bag-of-words model from the validation documents.
validationData = bagOfWords(documentsValidation);
For each of the LDA solvers, fit the model, set the initial topic concentration to 1, and specify not to fit the topic concentration parameter. Using the data in the FitInfo property of the fitted LDA models, plot the validation perplexity against the time elapsed, with the time elapsed on a logarithmic scale. This can take up to an hour to run.
The code for removing NaNs is necessary because of a quirk of the stochastic solver 'savb'. For this solver, the function evaluates the validation perplexity only after each pass through the data. It does not evaluate the validation perplexity for each iteration (mini-batch), and it reports NaN for those iterations in the FitInfo property. To plot the validation perplexity, remove the NaNs from the reported values.
figure
for i = 1:numel(solvers)
    solver = solvers(i);
    lineSpec = lineSpecs(i);

    mdl = fitlda(bag,numTopics, ...
        'Solver',solver, ...
        'InitialTopicConcentration',1, ...
        'FitTopicConcentration',false, ...
        'ValidationData',validationData, ...
        'Verbose',0);

    history = mdl.FitInfo.History;
    timeElapsed = history.TimeSinceStart;
    validationPerplexity = history.ValidationPerplexity;

    % Remove NaNs (iterations where no validation perplexity was computed).
    idx = isnan(validationPerplexity);
    timeElapsed(idx) = [];
    validationPerplexity(idx) = [];

    semilogx(timeElapsed,validationPerplexity,lineSpec)
    hold on
end
hold off
xlabel("Time Elapsed (s)")
ylabel("Validation Perplexity")
legend(solvers)
For the stochastic solver "savb", the function, by default, passes through the training data once. To process more passes of the data, set 'DataPassLimit' to a larger value (the default value is 1). For the batch solvers ("cgs", "avb", and "cvb0"), to reduce the number of iterations used to fit the models, set the 'IterationLimit' option to a lower value (the default value is 100).
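The following sketch shows both options on a tiny hypothetical corpus that stands in for the bag-of-words model built earlier. The corpus, the value 3 for 'DataPassLimit', and the value 20 for 'IterationLimit' are illustrative assumptions, not recommended settings. 'MiniBatchSize' is set to 2 only because the toy corpus has four documents.

```matlab
rng('default')

% Hypothetical toy corpus; substitute the bag-of-words model from the example.
documents = tokenizedDocument([
    "heavy rain flooded the road"
    "strong wind damaged trees"
    "heavy snow closed the road"
    "strong wind and heavy rain"]);
bag = bagOfWords(documents);
numTopics = 2;

% Stochastic solver: pass through the training data three times instead of once.
mdlSavb = fitlda(bag,numTopics, ...
    'Solver','savb', ...
    'DataPassLimit',3, ...
    'MiniBatchSize',2, ...
    'Verbose',0);

% Batch solver: cap the number of iterations at 20 instead of the default 100.
mdlCvb0 = fitlda(bag,numTopics, ...
    'Solver','cvb0', ...
    'IterationLimit',20, ...
    'Verbose',0);
```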
A lower validation perplexity suggests a better fit. Usually, the solver "cgs" converges quickly to a good fit. The solver "cvb0" might converge to a better fit, but it can take much longer to converge.
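It helps to recall the standard definition of perplexity: the exponential of the negative mean log-probability per word, exp(-sum(log p(document)) / total number of words). The following sketch assumes that the toolbox follows this definition with natural logarithms, and it uses hypothetical per-document values. The result shows why lower is better. A model that assigns higher probability to the held-out words has higher log-probabilities, and therefore a smaller perplexity.

```matlab
% Hypothetical natural-log probabilities of three validation documents.
logProb = [-20.5; -35.1; -12.4];

% Hypothetical number of words in each of those documents.
numWords = [5; 8; 3];

% Perplexity: exponential of the negative mean log-probability per word.
ppl = exp(-sum(logProb)/sum(numWords));
```

Here the sum of the log-probabilities is -68 over 16 words, so ppl equals exp(4.25).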
To report the validation perplexity in the FitInfo property, the fitlda function estimates it from the document probabilities at the maximum likelihood estimates of the per-document topic probabilities. This estimate is usually quicker to compute but can be less accurate than other methods. Alternatively, calculate the validation perplexity using the logp function, which calculates more accurate values but can take longer to run. For an example showing how to compute the perplexity using logp, see Calculate Document Log-Probabilities from Word Count Matrix.
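The following sketch shows the logp route on a tiny hypothetical corpus. It fits a model on three documents and then requests both the per-document log-probabilities and the perplexity of a held-out document. The corpus and the choice of 2 topics are assumptions for illustration only. Every word in the held-out document also appears in the training documents, so the sketch does not depend on how logp treats unseen words.

```matlab
rng('default')

% Hypothetical toy corpus; the last document is held out for validation.
documents = tokenizedDocument([
    "heavy rain flooded the road"
    "strong wind damaged trees"
    "heavy snow closed the road"
    "strong wind closed the road"]);
documentsTrain = documents(1:3);
documentsValidation = documents(4);

% Fit a small LDA model on the training documents.
bagTrain = bagOfWords(documentsTrain);
mdl = fitlda(bagTrain,2,'Verbose',0);

% Per-document log-probabilities and the perplexity of the held-out data.
[logProb,ppl] = logp(mdl,documentsValidation);
```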
The function preprocessText performs the following steps:
1. Tokenize the text using tokenizedDocument.
2. Lemmatize the words using normalizeWords.
3. Erase punctuation using erasePunctuation.
4. Remove a list of stop words (such as "and", "of", and "the") using removeStopWords.
5. Remove words with 2 or fewer characters using removeShortWords.
6. Remove words with 15 or more characters using removeLongWords.
function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or greater
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end