
Detect Anomalies in ECG Time-Series Data Using Wavelet Scattering and LSTM Autoencoder in Simulink

This example shows how to use wavelet scattering and a deep learning network within a Simulink® model to detect anomalies in ECG signals.

The example shows how to extract robust features from ECG signals using wavelet scattering, pass them through a long short-term memory (LSTM)-based encoder-decoder network that attempts to reconstruct the features, and use the reconstruction error to detect anomalies in the signal. For information on how to detect anomalies in ECG time-series data using the deepSignalAnomalyDetector object in MATLAB®, see Detect Anomalies in Machinery Using LSTM Autoencoder.

Data Description

This example uses ECG data obtained from the Sudden Cardiac Death Holter Database [1]. This database contains long-term ECG recordings of patients who experienced sudden cardiac death during the recordings. The dataset includes data from patients with different types of arrhythmia, such as atrial fibrillation and ventricular tachycardia.

This example uses ECG data from only one of the patients with a history of ventricular tachycardia and attempts to detect anomalies in the ECG data of the patient caused by ventricular tachycardia. The ECG signal has a sampling rate of 250 Hz.

Download and Prepare Data

Download the data from https://ssd.mathworks.com/supportfiles/SPT/data/PhysionetSDDB.zip using the downloadSupportFile function. The data set contains 320 seconds of ECG data from the patient, and the downloaded file contains two timetables. The timetable X contains the ECG signal of the patient. The timetable y contains annotated labels that indicate whether each sample of the ECG signal is normal or anomalous. You use the labels only to visualize the data set.

datasetZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/PhysionetSDDB.zip');
dataFolder = fullfile(tempdir,'PhysionetSDDB');
unzip(datasetZipFile,dataFolder);
ds2 = load(fullfile(dataFolder,"sddb49.mat"));
ecgSignals = ds2.X;
ecgLabels = ds2.y;

Visualize the ECG data with the annotated anomalies overlaid. Zoom in on the plot to better observe the anomalous region.

Notice that the ECG data contains gradual changes in the baseline of the signal. These gradual changes are known as baseline drift. Baseline drift is caused by factors such as respiration, movement artifacts, and changes in skin impedance, and it also occurs frequently in normal ECG data. Anomaly detection in ECG signals is challenging because these changes in baseline level can be misclassified as anomalies.

figure;
yyaxis left
plot(ecgSignals.Time,ecgSignals.Variables);
title("ECG Signal");
ylabel("ECG Amplitude")
yyaxis right;
plot(ecgSignals.Time,ecgLabels.anomaly)
xlabel("Time (s)")
legend(["Signal" "Label"],Location="southwest");
ylabel("Annotation")
yticks([0 1]);yticklabels({'Normal','Anomaly'})
ylim([-0.2,1.2]);

Split the data set into training and test sets. A common approach is to train on a segment of the signal that clearly contains no anomalies. The beginning of a recording is often normal, as it is in this ECG signal, so use the first 200 seconds of the recording to train the model with purely normal data. Use the rest of the recording to test the performance of the anomaly detector. The training data contains segments with baseline drift. Ideally, the detector learns the baseline drift and treats it as normal.

fs = 250;
idxTrain = 1:200*fs;
idxTest = idxTrain(end)+1:height(ecgSignals);
dataTrain = ecgSignals(idxTrain,:);
dataTest = ecgSignals(idxTest,:);
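The index arithmetic of the split can be checked without the data. This sketch assumes the recording length of 320 seconds at 250 Hz stated above; numSamples is a hypothetical stand-in for height(ecgSignals). It also confirms that both sets contain a whole number of one-second frames, which the later reshape into 250-sample columns requires.

```matlab
% Reproduce the train/test index split without loading the data.
% Assumption: 320 s of ECG sampled at 250 Hz, as described above.
fs = 250;
numSamples = 320*fs;                    % hypothetical stand-in for height(ecgSignals)
idxTrain = 1:200*fs;                    % first 200 s for training
idxTest = idxTrain(end)+1:numSamples;   % remaining 120 s for testing
fprintf("Training samples: %d (%d one-second frames)\n", numel(idxTrain), numel(idxTrain)/fs);
fprintf("Test samples: %d (%d one-second frames)\n", numel(idxTest), numel(idxTest)/fs);
```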

Normalize the training data using the normalize function, and obtain the mean and standard deviation from the output arguments C and S, respectively. Use these values to normalize the input data in Simulink when you run the model.

[dataProcessedTrain,C,S] = normalize(dataTrain);
meanVal = C.DISTORTEDsddb49
% meanVal = -30.4565
stdVal = S.DISTORTEDsddb49
% stdVal = 197.2967
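In the model, the Normalize subsystem receives the training mean and standard deviation. Assuming it applies the standard z-score transform (x - meanVal)/stdVal to each incoming sample, the minimal sketch below shows that sample-by-sample normalization with fixed training statistics gives the same result as normalizing a whole vector at once. The input x is synthetic, not the ECG data.

```matlab
% Per-sample z-score normalization with statistics fixed from training.
meanVal = -30.4565;   % training mean (value displayed above)
stdVal = 197.2967;    % training standard deviation (value displayed above)

x = 100*sin(2*pi*(0:249)'/250);   % synthetic 1-second input at 250 Hz
xStream = zeros(size(x));
for k = 1:numel(x)
    % Each incoming sample is centered and scaled independently.
    xStream(k) = (x(k) - meanVal)/stdVal;
end

xBatch = (x - meanVal)/stdVal;    % vectorized equivalent
maxDiff = max(abs(xStream - xBatch))
```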

Wavelet Scattering Network

Wavelet scattering is a powerful tool for signal analysis that captures both low-frequency and high-frequency information. The input signal is convolved with a series of wavelet filters at multiple scales, and the resulting coefficients are passed through a modulus nonlinearity and lowpass averaging to produce low-variance representations of the time series. This process extracts robust, discriminative features that are insensitive to time shifts in the input signal up to the invariance scale. For more information on feature extraction using wavelet scattering, see Wavelet Scattering (Wavelet Toolbox).

By decomposing the ECG signal into different frequency bands using wavelet transforms, wavelet scattering can effectively separate the baseline drift from the underlying cardiac activity. The wavelet scattering coefficients provide a representation of the signal that is more robust to baseline drift while remaining sensitive to other anomalies in the signal. By extracting features from these coefficients, it becomes possible to analyze the ECG signal while mitigating the effects of baseline drift.
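To make the filter, modulus, and averaging cascade concrete, the sketch below computes one first-order scattering path by hand on a synthetic signal. It is a simplified illustration, not what waveletScattering does internally: the complex Gabor-type wavelet, the Gaussian lowpass filter, and their widths are assumptions chosen for this example, and the final downsampling step is omitted. The comparison shows that adding a slow 0.2 Hz baseline drift barely changes the output of a path tuned to 10 Hz.

```matlab
% Simplified first-order scattering path: |x * psi| * phi.
% Assumption: hand-built filters, not the waveletScattering filter banks.
fs = 250;
t = (0:2*fs-1)'/fs;                         % 2 s of synthetic data
clean = sin(2*pi*10*t);                     % 10 Hz oscillation
drift = 0.5*sin(2*pi*0.2*t);                % slow 0.2 Hz baseline drift
x = clean + drift;

tw = (-0.2:1/fs:0.2)';                      % filter support, 101 taps
psi = exp(1i*2*pi*10*tw).*exp(-tw.^2/(2*0.05^2));   % complex wavelet centered at 10 Hz
psi = psi/sum(abs(psi));                    % unit gain at 10 Hz
phi = exp(-tw.^2/(2*0.05^2));               % Gaussian lowpass (averaging) filter
phi = phi/sum(phi);                         % unit DC gain

% Wavelet convolution, modulus, then lowpass averaging.
scatter1 = @(sig) conv(abs(conv(sig,psi,'same')),phi,'same');
s1 = scatter1(x);
s1Clean = scatter1(clean);

mid = 150:350;                              % avoid edge effects of 'same' convolution
maxDriftEffect = max(abs(s1(mid) - s1Clean(mid)))
```

The 10 Hz path responds to the oscillation with an envelope of about 0.5, while the drift lies far outside the wavelet passband and contributes only a small fraction of that.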

Split the 200-second training data into 200 one-second sequences of N = 250 samples each. Use an invariance scale of 0.5 seconds. To have enough scattering coefficients per time window to average, set OversamplingFactor to 2. Because the oversampling factor is specified on a log2 scale, this setting produces a 2^2 = four-fold increase in the number of scattering coefficients for each path relative to the critically downsampled value. With these settings, you obtain 94 scattering paths with 32 scattering coefficients for each path.

N = 250;   % samples per one-second sequence
sn = waveletScattering(SignalLength=N,SamplingFrequency=fs,...
    InvarianceScale=0.5,OversamplingFactor=2);

[spaths,npaths] = paths(sn);
npaths = sum(npaths)
% npaths = 94
ncoeffs = numCoefficients(sn)
% ncoeffs = 32
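Splitting the training signal into one-second sequences relies on reshape(...,N,[]). Because MATLAB stores arrays in column-major order, column k of the result holds samples (k-1)*N+1 through k*N, so each column is one non-overlapping frame. A minimal check, using a synthetic ramp in place of the ECG data:

```matlab
% Frame a signal into non-overlapping one-second columns.
N = 250;                        % samples per one-second frame at 250 Hz
x = (1:200*N)';                 % synthetic stand-in for 200 s of training data
X = reshape(x,N,[]);            % one frame per column
disp(size(X))                   % 250 200
firstOfFrame2 = X(1,2)          % 251
```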

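Exclude the zeroth-order scattering coefficients, which are a lowpass-filtered version of the signal and therefore retain most of its slowly varying baseline, and convert the features to a cell array with one sequence per one-second frame. To further improve the robustness of the features to baseline drift and better detect higher-frequency anomalies, consider removing additional lower-order coefficients.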

Xtrain = reshape(dataProcessedTrain.DISTORTEDsddb49,N,[]);
trainfeat = featureMatrix(sn,Xtrain);
trainfeat = trainfeat(2:end,:,:);
trainfeatcell = squeeze(num2cell(trainfeat,[1,2]));
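The shape bookkeeping in the code above can be traced with random numbers standing in for the featureMatrix output (an assumption for illustration only). The feature array is paths by coefficients by frames; dropping the first row removes the zeroth-order path, and num2cell followed by squeeze yields one 93-by-32 matrix per frame, which sequenceInputLayer reads as 93 channels over 32 time steps.

```matlab
% Shape bookkeeping for the scattering feature tensor.
% Assumption: rand replaces featureMatrix output; only the sizes matter.
npaths = 94; ncoeffs = 32; nframes = 200;
feat = rand(npaths,ncoeffs,nframes);        % paths x coefficients x frames
feat = feat(2:end,:,:);                     % drop the zeroth-order path -> 93 x 32 x 200
featCell = squeeze(num2cell(feat,[1,2]));   % 200 x 1 cell, one 93 x 32 sequence per frame
disp(size(featCell))                        % 200 1
disp(size(featCell{1}))                     % 93 32
```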

LSTM Autoencoder

Autoencoders are commonly used to detect anomalies in signals. The autoencoder is trained on features extracted from data without anomalies, so the learned network weights minimize the reconstruction error for features of normal ECG data. Features of anomalous data are reconstructed less accurately and therefore produce larger errors. You can use the statistics of the reconstruction error for the training data to select the threshold in the anomaly detection block, which determines the detection performance of the autoencoder. The detection block declares an anomaly when the reconstruction error exceeds the threshold. This example uses an LSTM autoencoder with root-mean-square error (RMSE) as the reconstruction error metric. For more information on detecting anomalies in signals using an LSTM autoencoder, see the deepSignalAnomalyDetectorLSTM object.
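The reconstruction error and decision rule described above can be written per frame as follows. Here P = 93 is the number of scattering paths kept after removing the zeroth-order path, M = 32 is the number of coefficients per path, X_k is the feature matrix of frame k, \hat{X}_k is its reconstruction, and \tau is the threshold. This is the same quantity as the mean over the first two dimensions that the example computes for each frame.

```latex
\mathrm{RMSE}_k = \sqrt{\frac{1}{PM}\sum_{p=1}^{P}\sum_{m=1}^{M}\left(\hat{X}_k[p,m]-X_k[p,m]\right)^2},
\qquad
\text{frame } k \text{ is anomalous} \iff \mathrm{RMSE}_k > \tau
```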

Create the autoencoder network. The encoder has one LSTM layer with 64 hidden units followed by a ReLU layer and a dropout layer. A repeat vector layer at the end of the encoder repeats the encoder output to produce sequences of the same length as the input sequences. The decoder consists of an LSTM layer with 64 hidden units and a ReLU layer, followed by a fully connected layer. A relatively small network suffices here because the extracted features are robust and insensitive to baseline drift.
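The repeat vector step is what lets a single encoder state drive a full-length decoder sequence. The sketch below illustrates that step with repmat; it assumes, as the description above implies, that the layer copies its input vector across ncoeffs time steps, and it is not the implementation of repeatVectorLayer itself. The vector h is a random stand-in for the last hidden state of the 64-unit encoder LSTM.

```matlab
% Illustration of the repeat-vector step (not the layer's implementation).
numHidden = 64; ncoeffs = 32;
h = randn(numHidden,1);        % stand-in for the encoder's last hidden state
H = repmat(h,1,ncoeffs);       % 64 x 32: the same vector at every time step
disp(size(H))                  % 64 32
% The decoder LSTM maps H to a 64 x 32 sequence, and the fully connected
% layer maps each time step to 93 outputs, matching the 93 x 32 input.
```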

pLayers = [sequenceInputLayer(npaths-1,Name="Input Layer",MinLength=ncoeffs),...
            lstmLayer(64,OutputMode="last"),...
            reluLayer,...
            dropoutLayer(0.2),...            
            repeatVectorLayer(ncoeffs),...
            lstmLayer(64),...
            reluLayer,...
            fullyConnectedLayer(npaths-1),...
            regressionLayer]; 

Specify the hyperparameters. Use Adam optimization and a mini-batch size of 50. Set the maximum number of epochs to 250.

options = trainingOptions('adam',...
    MaxEpochs=250,...
    MiniBatchSize=50,...
    Plots='training-progress',...
    Verbose=false);
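With 200 one-second training sequences and a mini-batch size of 50, the training run is short. The arithmetic below is a back-of-the-envelope check; because 50 divides 200 evenly, no partial mini-batch is left over in any epoch.

```matlab
% Training-length arithmetic for the options above.
numSequences = 200;   % one-second training frames
miniBatchSize = 50;
maxEpochs = 250;
itersPerEpoch = floor(numSequences/miniBatchSize)   % 4
totalIters = itersPerEpoch*maxEpochs                % 1000
```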

Train the network.

netECGAnomaly = trainNetwork(trainfeatcell,trainfeatcell,pLayers,options);

Obtain Threshold

You can use the statistics of the reconstruction error on the training data to determine the threshold for detecting anomalies in the ECG data during inference. Compute the RMSE between the input features and the reconstructed features of the training data.

reconstrTrainFeatCell = predict(netECGAnomaly,trainfeatcell);
reconstrTrainFeat = cell2mat(permute(reconstrTrainFeatCell,[2 3 1]));
rmseTrain = sqrt(mean((reconstrTrainFeat-trainfeat).^2,[1 2]));
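Two steps in the code above are easy to misread: permute(...,[2 3 1]) turns the 200-by-1 cell array returned by predict into a 1-by-1-by-200 cell array, so cell2mat stacks the 93-by-32 reconstructions along the third dimension; and mean(...,[1 2]) averages over paths and coefficients, giving one RMSE per frame. The sketch below checks both on synthetic stand-ins (an assumption: random features and perturbed "reconstructions" in place of network output) by comparing against an explicit loop.

```matlab
% Synthetic stand-ins: 93 x 32 x 200 features and a 200 x 1 cell array
% of reconstructions shaped like the output of predict.
feat = rand(93,32,200);
reconCell = cell(200,1);
for k = 1:200
    reconCell{k} = feat(:,:,k) + 0.01*randn(93,32);
end

% Reorder the cell array to 1 x 1 x 200 so cell2mat stacks along dim 3.
recon = cell2mat(permute(reconCell,[2 3 1]));

% Vectorized per-frame RMSE, as in the example.
rmseVec = sqrt(mean((recon - feat).^2,[1 2]));   % 1 x 1 x 200

% Equivalent explicit loop over frames.
rmseLoop = zeros(1,1,200);
for k = 1:200
    d = recon(:,:,k) - feat(:,:,k);
    rmseLoop(k) = sqrt(mean(d(:).^2));
end
maxDiff = max(abs(rmseVec(:) - rmseLoop(:)))
```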

Obtain the maximum RMSE value for the training data.

maxrmseTrain = max(rmseTrain)
% maxrmseTrain = 0.0116 (single)

To leave a margin for baseline drift during inference, choose a threshold slightly higher than the maximum RMSE observed in the training data. This example sets the threshold 25% above that maximum.

threshold = 1.25*maxrmseTrain
% threshold = 0.0146 (single)

Open the Simulink model

Open the Simulink model attached to this example. Use this model to extract wavelet scattering features and detect anomalies in ECG data using the network trained in the previous steps.

open_system('ECGAnomalyDetection.slx');

Figure: ECGAnomalyDetection model

Read the test ECG data one sample at a time and normalize it using the Normalize subsystem block. In the dialog box of the Normalize block, enter the mean and standard deviation calculated from the training data. Pass the normalized ECG data to the Wavelet Scattering block, configure that block with the same parameters you used to extract features from the training data, and pass the extracted features to the Predict block. The Predict block uses the network trained in the 'LSTM Autoencoder' section. The Post-processing block receives both the input features and the reconstructed features. In the dialog box of the Post-processing subsystem block, enter the threshold value you calculated in the 'Obtain Threshold' section.

Figure: ECGAnomalyDetection model (detail)

The Post-processing block calculates the RMSE between the input and reconstructed features and compares it to the threshold to determine whether the current frame of input is anomalous. The Repeat block then applies the same anomaly decision to all 250 input samples in the current frame.
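The frame-level logic of the Post-processing and Repeat blocks can be sketched in MATLAB. This is a hypothetical stand-in, not the contents of those blocks: it assumes they compute the RMSE the same way as the MATLAB code earlier in this example, and it uses synthetic features with one close and one poor reconstruction. The threshold value is the one computed in the Obtain Threshold section.

```matlab
% Hypothetical stand-in for the Post-processing and Repeat logic.
threshold = 0.0146;     % threshold computed in the Obtain Threshold section
N = 250;                % input samples per frame
rng(0);                 % reproducible synthetic data
feat = rand(93,32);                          % input features for one frame
reconGood = feat + 0.001*randn(93,32);       % close reconstruction
reconBad = feat + 0.05*randn(93,32);         % poor reconstruction

frameRMSE = @(recon) sqrt(mean((recon(:) - feat(:)).^2));
decision = [frameRMSE(reconGood) frameRMSE(reconBad)] > threshold   % 0 1

% Repeat each frame decision for all N samples of its frame.
sampleDecision = repelem(decision(:),N,1);   % 500 x 1
```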

Delay the input signal to account for the buffering in the Wavelet Scattering block, so that each frame of the signal lines up with its anomaly decision. Visualize the delayed signal and the logical output of the Post-processing block using a Time Scope block.
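The reason for the delay can be shown with a toy timeline. The exact delay used in the model is not stated here; this sketch assumes that the decision for a frame becomes available only after the whole frame has been buffered, so it appears one frame (N = 250 samples) later, and that delaying the input by the same N samples aligns the two. The per-frame decisions are hypothetical values.

```matlab
% Toy alignment of a delayed signal with frame decisions.
% Assumption: one frame (N samples) of latency from buffering.
N = 250;
x = (1:4*N)';                            % synthetic input: 4 frames of samples
xDelayed = [zeros(N,1); x(1:end-N)];     % delay by N samples, zero initial condition

frameDecision = [0;1;0;0];               % hypothetical per-frame decisions
% The decision for frame k is emitted during frame k+1.
decisionStream = [zeros(N,1); repelem(frameDecision(1:end-1),N,1)];

% Samples of frame 2 (251:500) now coincide with frame 2's decision.
alignedSamples = xDelayed(2*N+1:3*N);
alignedDecision = decisionStream(2*N+1:3*N);
```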

Simulate the model.

sim('ECGAnomalyDetection.slx');

The Time Scope shows the input ECG signal overlaid with the logical anomaly decision. Portions of the signal where the decision is 0 are normal, and portions where the decision is 1 are anomalous.

Figure: Time Scope display of the input ECG signal and the anomaly decision

Analysis

Extract features from the test data, reconstruct the features, and compute the RMSE between input and reconstructed features.

dataProcessedTest = normalize(dataTest,"center",C,"scale",S);
Xtest = reshape(dataProcessedTest.DISTORTEDsddb49,N,[]);
testfeat = featureMatrix(sn,Xtest);
testfeat = testfeat(2:end,:,:);
testfeatcell = squeeze(num2cell(testfeat,[1,2]));

reconstrTestFeatCell = predict(netECGAnomaly,testfeatcell);
reconstrTestFeat = cell2mat(permute(reconstrTestFeatCell,[2 3 1]));
rmseTest = sqrt(mean((reconstrTestFeat-testfeat).^2,[1 2]));

Repeat the RMSE value obtained for each frame of the input data for all the samples in that frame. Plot the threshold and the RMSE values for the training and test data, and overlay the annotation labels on the plot.

The low RMSE values show that the network reconstructs the normal data well. In regions containing anomalies, the error is significantly higher than for the normal signal, which enables robust detection of anomalies. Regions corresponding to baseline drift have relatively higher RMSE than other normal regions; however, these values remain well below the RMSE in the anomalous regions.

rmseTestExpanded = repelem(rmseTest,1,1,N);
rmseTrainExpanded = repelem(rmseTrain,1,1,N);
figure;
yyaxis left;
plot(dataTrain.Time,rmseTrainExpanded(:));
hold on;
plot(dataTest.Time,rmseTestExpanded(:),'m-');
plot(ecgSignals.Time,threshold*ones(height(ecgSignals.Time),1),'g-',LineWidth=2);
xlabel('Time (s)');
ylabel('RMSE');
yyaxis right
plot(ecgSignals.Time,ecgLabels.anomaly,LineWidth=1);
yticks([0 1]);yticklabels({'Normal','Anomaly'})
ylim([-0.2,1.2]);
ylabel("Annotation")
legend(["Training data" "Test data" "Threshold" "Ground truth label"],Location="northwest");

References

[1] Greenwald, Scott David. "Development and Analysis of a Ventricular Fibrillation Detector." M.S. thesis, MIT Department of Electrical Engineering and Computer Science, 1986.

[2] Nawaz, M., and J. Ahmed. "Cloud-Based Healthcare Framework for Real-Time Anomaly Detection and Classification of 1-D ECG Signals." PLoS ONE 17, no. 12 (2022): e0279305. https://doi.org/10.1371/journal.pone.0279305.
