Hi everyone, I have a question about PINNs.
Hi everyone, I recently started working on PINNs and am trying to apply a Lie-symmetry-enhanced PINN (sPINN). To that end, I trained the KdV equation from [1] under the same conditions. In theory, sPINN should give a better approximation than a classic PINN, but with my code it does not, so I think something is missing. Starting from the code in [2], I changed the initial condition and the equation, and added the symmetry condition.
[1] Enforcing continuous symmetries in physics-informed neural network for solving forward and inverse problems of partial differential equations
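For reference, the setup my code implements is the KdV equation

u_t + u*u_x + u_xxx = 0 on (x,t) in [0,1] x [0,1],

whose exact one-soliton solution

u(x,t) = 12*sech(x - 4*t)^2

supplies the initial and boundary data. The symmetry I enforce is the travelling-wave symmetry v = d/dt + 4*d/dx, under which this soliton is invariant; its invariant surface condition (ISC)

u_t + 4*u_x = 0

is the extra residual mseG in the loss function below.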
clear all;
close all;
clc;
%% Generate Training Data
% Number of points on the left (x = 0) and right (x = 1) boundaries
numBoundaryConditionPoints = [128 128];
% Row vectors of 128 boundary x-coordinates
x0BC1 = zeros(1,numBoundaryConditionPoints(1));
x0BC2 = ones(1,numBoundaryConditionPoints(2));
% 128 equally spaced time points between 0 and 1
t0BC1 = linspace(0,1,numBoundaryConditionPoints(1));
t0BC2 = linspace(0,1,numBoundaryConditionPoints(2));
% Boundary values from the exact soliton 12*sech(x - 4*t)^2 evaluated at x = 0 and x = 1
u0BC1 = 12*sech(-4*t0BC1).^2;
u0BC2 = 12*sech(1 - 4*t0BC2).^2;
numInitialConditionPoints = 256;
x0IC = linspace(0,1,numInitialConditionPoints);
t0IC = zeros(1,numInitialConditionPoints);
% Initial condition: the exact soliton at t = 0
u0IC = 12*sech(x0IC).^2;
% Group together the data for initial and boundary conditions.
X0 = [x0IC x0BC1 x0BC2];
T0 = [t0IC t0BC1 t0BC2];
U0 = [u0IC u0BC1 u0BC2];
% Number of interior collocation points for the PDE and symmetry residuals
numInternalCollocationPoints = 10000;
% Generate uniform random collocation points in the unit square [0,1] x [0,1],
% matching the domain of the boundary conditions and the test grid
points = rand(numInternalCollocationPoints,2);
dataX = points(:,1);
dataT = points(:,2);
%% Define Neural Network Architecture
numBlocks = 8;
fcOutputSize = 20;
% Fundamental block that is repeated numBlocks times
fcBlock = [
    fullyConnectedLayer(fcOutputSize)
    tanhLayer];
layers = [
    featureInputLayer(2)          % input layer: 2 features, the (x,t) coordinates
    repmat(fcBlock,[numBlocks 1]) % deep network: input -> block 1 -> ... -> block numBlocks
    fullyConnectedLayer(1)        % scalar output u(x,t)
    ];
% Convert the layer array to a dlnetwork object.
net = dlnetwork(layers);
% Training a PINN can result in better accuracy when the learnable parameters have data type double.
% Convert the network learnables to double using the dlupdate function.
% Note that not all neural networks support learnables of type double, for example, networks that use GPU optimizations that rely on learnables with type single.
net = dlupdate(@double,net);
%% Define Model Loss Function
% This is the core loss function that makes Physics-Informed Neural Networks (PINNs) work! It's where the "physics" is enforced.
function [loss,gradients] = modelLoss(net,X,T,X0,T0,U0)
    % Make predictions at the interior collocation points.
    XT = cat(1,X,T);
    U = forward(net,XT);

    % Calculate first-order derivatives with respect to X and T.
    X = stripdims(X);
    T = stripdims(T);
    U = stripdims(U);
    Ux = dljacobian(U,X,1);
    Ut = dljacobian(U,T,1);

    % Calculate second- and third-order derivatives with respect to X.
    Uxx = dldivergence(Ux,X,1);
    Uxxx = dldivergence(Uxx,X,1);

    % mseF (physics loss): residual of the KdV equation u_t + u*u_x + u_xxx = 0.
    f = Ut + U.*Ux + Uxxx;
    mseF = mean(f.^2);

    % mseG (symmetry loss): residual of the invariant surface condition
    % u_t + 4*u_x = 0 for the travelling-wave symmetry v = d_t + 4*d_x.
    g = 4.*Ux + Ut;% + (X.*U)./2;
    mseG = mean(g.^2);

    % mseU (data loss): initial and boundary conditions.
    XT0 = cat(1,X0,T0);
    U0Pred = forward(net,XT0);
    mseU = l2loss(U0Pred,U0);

    % Total loss.
    loss = mseF + mseU + mseG;

    % Gradients with respect to the learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end
%% Specify the training options:
solverState = lbfgsState;
maxIterations = 500;
gradientTolerance = 1e-5;
stepTolerance = 1e-5;
%% Train Neural Network
% Convert the training data to dlarray objects.
% Specify that the inputs X and T have format "BC" (batch, channel) and that the initial conditions have format "CB" (channel, batch).
X = dlarray(dataX,"BC");
T = dlarray(dataT,"BC");
X0 = dlarray(X0,"CB");
T0 = dlarray(T0,"CB");
U0 = dlarray(U0,"CB");
% Accelerate the loss function using the dlaccelerate function.
accfun = dlaccelerate(@modelLoss);
% Create a function handle containing the loss function for the L-BFGS update step.
% In order to evaluate the dlgradient function inside the modelLoss function using automatic differentiation, use the dlfeval function.
lossFcn = @(net) dlfeval(accfun,net,X,T,X0,T0,U0);
% Initialize the TrainingProgressMonitor object.
% At each iteration, plot the loss and monitor the norm of the gradients and steps.
% Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor( ...
    Metrics="TrainingLoss", ...
    Info=["Iteration" "GradientsNorm" "StepNorm"], ...
    XLabel="Iteration");
iteration = 0;
while iteration < maxIterations && ~monitor.Stop
    iteration = iteration + 1;
    [net, solverState] = lbfgsupdate(net,lossFcn,solverState);

    updateInfo(monitor, ...
        Iteration=iteration, ...
        GradientsNorm=solverState.GradientsNorm, ...
        StepNorm=solverState.StepNorm);
    recordMetrics(monitor,iteration,TrainingLoss=solverState.Loss);
    monitor.Progress = 100*iteration/maxIterations;

    if solverState.GradientsNorm < gradientTolerance || ...
            solverState.StepNorm < stepTolerance || ...
            solverState.LineSearchStatus == "failed"
        break
    end
end
%% Exact solution (one-soliton of the KdV equation)
function U = solveEq(X,T)
    U = 12*sech(X - 4.*T).^2;
end
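% Optional sanity check (a sketch, assuming the Symbolic Math Toolbox is
% available): verify that the exact soliton satisfies both the KdV
% equation and the invariant surface condition enforced by mseG.
if license('test','Symbolic_Toolbox')
    syms xs ts
    us = 12*sech(xs - 4*ts)^2;
    % Both residuals should simplify to zero.
    kdvResidual = simplify(diff(us,ts) + us*diff(us,xs) + diff(us,xs,3))
    iscResidual = simplify(diff(us,ts) + 4*diff(us,xs))
end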
%% Evaluate Model Accuracy
tTest = [0.25 0.5 0.75 1];
numObservationsTest = numel(tTest);
szXTest = 1001;
XTest = linspace(0,1,szXTest);
XTest = dlarray(XTest,"CB");
% Test the model.
UPred = zeros(numObservationsTest,szXTest);
UTest = zeros(numObservationsTest,szXTest);
for i = 1:numObservationsTest
    t = tTest(i);
    TTest = repmat(t,[1 szXTest]);
    TTest = dlarray(TTest,"CB");
    XTTest = cat(1,XTest,TTest);
    UPred(i,:) = forward(net,XTTest);
    UTest(i,:) = solveEq(extractdata(XTest),t);
end
% Relative L2 error over all test points (Frobenius norm, not the matrix 2-norm)
err = norm(UPred - UTest,"fro") / norm(UTest,"fro");
fprintf('Relative error: %e\n', err);
figure
tiledlayout("flow")
for i = 1:numel(tTest)
    nexttile
    plot(XTest,UPred(i,:),"-",LineWidth=2);
    hold on
    plot(XTest,UTest(i,:),"--",LineWidth=2)
    hold off
    ylim([0, 13])
    xlabel("x")
    ylabel("u(x," + tTest(i) + ")") % label each tile with its own time
end
legend(["Prediction" "Target"])
%% Create Density Plots with Rainbow Color Range
% Create a finer grid for density plots
xGrid = linspace(0, 1, 200);
tGrid = linspace(0, 1, 100);
[XGrid, TGrid] = meshgrid(xGrid, tGrid);
% Evaluate the network on the whole grid in one batched forward call;
% calling forward on one scalar point at a time is much slower.
XTGrid = dlarray([XGrid(:)'; TGrid(:)'],"CB");
UPredDensity = reshape(extractdata(forward(net,XTGrid)),size(XGrid));
% Exact solution on the same grid (solveEq is elementwise, so it vectorizes).
UTestDensity = solveEq(XGrid,TGrid);
% Create density plots
figure('Position', [100, 100, 1200, 500]);
% Predicted solution density plot
subplot(1,2,1);
imagesc(xGrid, tGrid, UPredDensity);
colormap(jet); % Rainbow colormap
colorbar;
axis xy; % Put t = 0 at the bottom (time increasing upward)
xlabel('x');
ylabel('t');
title('Predicted Solution Density');
set(gca, 'FontSize', 12);
% Exact solution density plot
subplot(1,2,2);
imagesc(xGrid, tGrid, UTestDensity);
colormap(jet); % Rainbow colormap
colorbar;
axis xy; % Put t = 0 at the bottom (time increasing upward)
xlabel('x');
ylabel('t');
title('Exact Solution Density');
set(gca, 'FontSize', 12);
% Add a main title
sgtitle('Solution Comparison - Density Plots', 'FontSize', 14, 'FontWeight', 'bold');
%% Line plots at specific time points
figure('Position', [100, 100, 1000, 800]);
tiledlayout("flow")
for i = 1:numel(tTest)
    nexttile
    plot(XTest,UPred(i,:),"-",LineWidth=2);
    hold on
    plot(XTest,UTest(i,:),"--",LineWidth=2)
    hold off
    ylim([0, 13])
    xlabel("x")
    ylabel("u(x," + tTest(i) + ")")
    title(sprintf('t = %.2f', tTest(i)))
end
legend(["Prediction" "Target"])
sgtitle('Solution Comparison - Time Slices', 'FontSize', 14, 'FontWeight', 'bold');
Comment from Anumeha on 9 Dec 2025:
Hi Ughur,
To implement sPINN, you need to add ISC (Invariant Surface Condition) constraints to your loss function. These constraints are derived from the Lie or non-classical symmetries of your PDE and ensure the neural network's solution respects those symmetries. Specifically, compute the ISC residual at your collocation points and include its mean squared error in your total loss. Hope this works for you.
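For concreteness, here is a minimal sketch of that inside modelLoss. For a general point symmetry generator v = tau*d/dt + xi*d/dx + eta*d/du, the ISC residual is tau*u_t + xi*u_x - eta; the coefficients below reproduce the travelling-wave case already in the posted code (tau = 1, xi = 4, eta = 0), and the loss weights lambdaF, lambdaU, lambdaG are illustrative tuning knobs rather than values taken from [1]:

tau = 1; xi = 4; eta = 0;    % infinitesimals; in general they may depend on X, T, U
g = tau.*Ut + xi.*Ux - eta;  % ISC residual: tau*u_t + xi*u_x - eta = 0
mseG = mean(g.^2);

% Weighted total loss; balancing the PDE, data, and symmetry terms is a
% common PINN practice (the lambdas here are hypothetical, not from [1]).
lambdaF = 1; lambdaU = 1; lambdaG = 1;
loss = lambdaF*mseF + lambdaU*mseU + lambdaG*mseG;

If weighting does not help, it is also worth checking that the ISC you impose is actually a symmetry of both the equation and the chosen initial/boundary data.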