Code covered by the BSD License  

Highlights from
RungeKuttaTimeIntegrationExplorer

image thumbnail

RungeKuttaTimeIntegrationExplorer

by

 

24 Jul 2013 (Updated )

Launches an interactive GUI for exploring Runge-Kutta time integration schemes.

RungeKuttaTimeIntegrationExplorer
function RungeKuttaTimeIntegrationExplorer
% RungeKuttaTimeIntegrationExplorer  GUI for exploring different
% Runge-Kutta time integration schemes.
%
% Solves the model problem du/dt = g(t,u), u(0) = u0 on [0, TimeFinal] with
% forward Euler (listed as "Forward Difference"), the explicit midpoint
% Runge-Kutta 2 method, classical Runge-Kutta 4 and the two-stage
% Gauss-Legendre ("GL") implicit Runge-Kutta method, and shows
%   * the numerical and exact solutions for the step size chosen with the
%     slider (top left),
%   * the error at the final time versus the refinement level n, where
%     dt = 2^-n, optionally annotated with observed convergence rates
%     (bottom left),
%   * the linear stability region of each method in the lambda*dt plane
%     (bottom right).
%   Copyright 2013 The MathWorks, Inc.


%% Initialize the parameters
% Right-hand side g(t,u), its Jacobian dg/du (used by the Newton solve of
% the implicit method) and the exact solution; one cell per problem listed
% in FunctionNames.
fun{1}       = @(t,u) (-2*u);
dfun{1}      = @(t,u) (-2);
fun_exact{1} = @(t) exp(-2*t);
handles.fun       = fun;
handles.dfun      = dfun;
handles.fun_exact = fun_exact;

handles.FunctionNames = {
    'du/dt = -2 * u'
    };
handles.TimeFinal = 1;
handles.u0        = 1;

handles.xdom = [-0.6 0.6];

% per-problem y-limits for the error plot (stored, currently not applied)
handles.ylim = [
    1e-5 1e0
    1e-2 1e1
    1e-1 1e2
    1e-2 1e1
    1e-4 1e1
    ];

handles.Methods = {
    'Forward Difference'
    'Runge-Kutta 2'
    'Runge-Kutta 4'
    'GL Implicit Runge-Kutta 2'
    };

handles.VisibleFlag = ones(length(handles.Methods),1);

% refinement levels n; the step size is dt = 2^-n
handles.nnint = 0:5;

handles.StabilityXdom = [-3 3];
handles.StabilityYdom = [-3 3];

handles.Color = [
    0.00 0.00 1.00
    0.00 0.50 0.00
    1.00 0.00 0.00
    0.00 0.75 0.75
    0.75 0.00 0.75];
handles.StabilityColor = [
    0.729412 0.831373 0.956863
    0.7569    0.8667    0.7765
    0.9255    0.8392    0.8392
    0.7255    0.9255    0.9216
    0.9882    0.7569    0.9882
    ];
handles.Markers = [
    'o'
    'x'
    '+'
    '*'
    's'
    ];

handles.LineWidth            = 1.5;
handles.MarkerSizeError      = 20;
handles.MarkerSizeCurrentn   = 40;
handles.MarkerSizeMethodPnts = 10;
handles.FontSize             = 12;

%% build up the GUI
fh = figure('Visible','on','Name','RungeKuttaTimeIntegrationExplorer');
set(fh,'MenuBar','none');
pos_orig = get(fh,'Position');
Width = 750; Height = 750;
set(fh,'Position',[pos_orig(1), pos_orig(2)+pos_orig(4)-Height, Width, Height]);
set(fh,'units','normalized');

% add invisible axis to hold text in case we want latex
handles.textaxis = axes('parent',fh,...
    'units','normalized',...
    'position',[0 0 1 1],...
    'visible','off');

bottom = 0.575;
width  = 0.425;
horiz_space = 0.05;
height = 0.4;
% add axis for plotting solution
handles.PlotAxes = axes('Parent',fh,...
    'Position',[horiz_space, bottom, width, height]);
handles = AddLinesToPlotAxes(handles);

bottom = 0.075;
% add axis for plotting error convergence
handles.ErrorAxes = axes('Parent',fh,...
    'Position',[horiz_space, bottom, width, height]);
handles = AddLinesToErrorAxes(handles);

% add axis for plotting stability
handles.StabilityAxes = axes('Parent',fh,...
    'Position',[2*horiz_space+width, bottom, width, height]);
handles = AddStabilityPlot(handles);

% add the slider for the refinement level n
nmin = min(handles.nnint);
nmax = max(handles.nnint);
n0   = 2;   % initial refinement level
handles.NControlSlider = uicontrol(fh,'Style','slider',...
    'Max',nmax,'Min',nmin,...
    'Value',n0,...
    'SliderStep',1/(nmax-nmin)*[1 1],...
    'FontSize',handles.FontSize,...
    'units','normalized',...
    'Position',[2*horiz_space+width+0.025,0.525, .35, .05],...
    'CallBack',@UpdatePlotsFromNControlSlider);
axes(handles.textaxis)
% label text uses the same format as UpdatePlotsFromNControlSlider so it
% matches the slider's initial value
handles.NControlText = text(2*horiz_space+width+0.0+width/2,.5,...
    sprintf('N = 2^%d = %d, \\Delta t = 1/%d',n0,2^n0,2^n0),...
    'unit','normalized',...
    'HorizontalAlignment','center',...
    'FontSize',handles.FontSize);

% add function selector
handles.FunctionSelector = uicontrol(fh,'Style','popupmenu',...
    'String',handles.FunctionNames,...
    'Value',1,...
    'FontSize',handles.FontSize,...
    'units','normalized',...
    'Position',[2*horiz_space+width+0.025,0.875, .35, .1],...
    'CallBack',@GetErrorData);

% add MethodPanel
handles.MethodPanel = uibuttongroup('Parent',fh,...
    'Title','Time Integration Methods',...
    'Position',[2*horiz_space+width+0.025,0.675, .35, .25]);

controlleft = 0.1;
controlwidth = 0.9;
controlheight = 0.1;

for i=1:length(handles.Methods)
    handles.MethodBoxes(i) = uicontrol(handles.MethodPanel,'Style','checkbox',...
        'String',handles.Methods{i},...
        'Value',1,...
        'FontSize',handles.FontSize,...
        'units','normalized',...
        'ForegroundColor', handles.Color(i,:),...
        'Position',[controlleft, 1-.2*i, controlwidth, controlheight],...
        'CallBack',@UpdatePlotsFromMethods);
end

% add checkbox for rate triangle
handles.RateTriangle = uicontrol(fh,'Style','checkbox',...
        'String','Display Convergence Triangles',...
        'Value',0,...
        'FontSize',handles.FontSize,...
        'units','normalized',...
        'Position',[2*horiz_space+width+0.025,0.6, .35, .05],...
        'CallBack',@UpdatePlots);

%% store handles as appdata in fh
setappdata(fh,'handles',handles);

%% make initial plot
GetErrorData(handles.FunctionSelector);

end

function handles = AddLinesToPlotAxes(handles)
% AddLinesToPlotAxes  Create the (initially empty) line objects of the
% solution plot on handles.PlotAxes and store their handles.
%
% Adds to handles:
%   hMethodLines - one marker-only line per method for the numerical u(t)
%   hMethodPnts  - one extra 'o'-marker line per method (not filled by the
%                  visible update code; kept so the handles struct is unchanged)
%   hFunction    - dashed black line for the exact solution
PlotAxes = handles.PlotAxes;
% the solution is plotted over the time interval [0, TimeFinal]
set(PlotAxes,'XLim',[0 handles.TimeFinal]);
xlabel(PlotAxes,'t');
ylabel(PlotAxes,'u(t)');

% retrieve the standard colors and sizes
Color        = handles.Color;
LW           = handles.LineWidth;
MSmethodpnts = handles.MarkerSizeMethodPnts;
Markers      = handles.Markers;

% build necessary line handles (explicit parent instead of relying on gca)
for i=1:length(handles.Methods)
    hMethodLine(i) = line('Parent',PlotAxes,...
        'XData',[],'YData',[],...
        'LineStyle','none',...
        'Color',Color(i,:),...
        'LineWidth',LW,...
        'Marker',Markers(i,:),...
        'MarkerSize',MSmethodpnts);

    hMethodPnts(i) = line('Parent',PlotAxes,...
        'XData',[],'YData',[],...
        'LineStyle','none',...
        'Color',Color(i,:),...
        'Marker','o',...
        'MarkerSize',MSmethodpnts);
end

% exact-solution line
hFunction = line('Parent',PlotAxes,...
    'XData',[],'YData',[],...
    'LineStyle','--',...
    'Color',[0 0 0],...
    'LineWidth',LW);

% store handles for later use
handles.hMethodLines      = hMethodLine;
handles.hMethodPnts       = hMethodPnts;
handles.hFunction         = hFunction;
end

function handles = AddLinesToErrorAxes(handles)
% AddLinesToErrorAxes  Create the (initially hidden, empty) graphics objects
% of the convergence plot on handles.ErrorAxes and store their handles.
%
% Adds to handles, one entry per method:
%   hErrorLines             - error versus refinement level
%   hErrorRateTriangles     - dashed convergence-rate triangle
%   hErrorRateTrianglesText - text label with the observed rate
%   hErrorCurrentn          - large dot at the current slider level
ErrorAxes = handles.ErrorAxes;
% errors span several decades: use a log scale
set(ErrorAxes,'YScale','log')
% pad the refinement range by one level on each side
set(ErrorAxes,'XLim',[handles.nnint(1)-1 handles.nnint(end)+1]);
xlabel(ErrorAxes,'number of refinements');
% the error is measured at the final time, so label with TimeFinal
ylabel(ErrorAxes,sprintf('error at t = %g',handles.TimeFinal));

% retrieve default colors and sizes
Color   = handles.Color;
LW      = handles.LineWidth;
MSerror = handles.MarkerSizeError;
MScurn  = handles.MarkerSizeCurrentn;

for i=1:length(handles.Methods)
    hErrorLine(i) = line('Parent',ErrorAxes,...
        'XData',[],'YData',[],...
        'Marker','.','LineStyle','-','Visible','off',...
        'Color',Color(i,:),...
        'LineWidth',LW,...
        'MarkerSize',MSerror);

    hErrorRateTriangle(i) = line('Parent',ErrorAxes,...
        'XData',[],'YData',[],...
        'LineStyle','--','Visible','off',...
        'Color',Color(i,:),...
        'LineWidth',LW);

    hErrorRateTriangleText(i) = text('Parent',ErrorAxes,...
        'String','',...
        'HorizontalAlignment','left',...
        'Color',Color(i,:));

    hErrorCurrentn(i) = line('Parent',ErrorAxes,...
        'XData',[],'YData',[],...
        'Marker','.','Visible','off',...
        'Color',Color(i,:),...
        'MarkerSize',MScurn);
end

% store the handles for later use
handles.hErrorLines             = hErrorLine;
handles.hErrorRateTriangles     = hErrorRateTriangle;
handles.hErrorRateTrianglesText = hErrorRateTriangleText;
handles.hErrorCurrentn          = hErrorCurrentn;

end

function UpdatePlotsFromNControlSlider(hObject, ~)
% Slider callback: snap the slider to an integer refinement level, show the
% matching number of steps N = 2^level and step size 1/N, then redraw.
handles = getappdata(get(hObject,'Parent'),'handles');
slider  = handles.NControlSlider;

level = round(get(slider,'Value'));
set(slider,'Value',level);

nsteps = 2^level;
set(handles.NControlText,'String', ...
    sprintf('N = 2^%d = %d, \\Delta t = 1/%d',level,nsteps,nsteps));
UpdatePlots(hObject);
end

function UpdatePlotsFromMethods(hObject,~)
% Checkbox callback for the method panel. The checkboxes are children of
% the uibuttongroup, so pass the panel on: its 'Parent' is the figure that
% UpdatePlots reads the stored handles from.
UpdatePlots(get(hObject,'Parent'));
end

function GetErrorData(hObject, ~)
% Recompute the error table for the selected problem, store it back in the
% figure's appdata and redraw. Used as the function-selector callback and
% for the initial plot.
figHandle = get(hObject,'Parent');
refreshed = ComputeError(getappdata(figHandle,'handles'));
setappdata(figHandle,'handles',refreshed);
UpdatePlots(hObject);
end

function UpdatePlots(hObject, ~)
% Redraw every plot. hObject must be a direct child of the GUI figure; its
% 'Parent' is used to fetch the stored handles. The per-method visibility
% flags are refreshed from the checkboxes and passed to the plot updaters
% (they are not written back to appdata).
handles = getappdata(get(hObject,'Parent'),'handles');

for k = length(handles.Methods):-1:1
    handles.VisibleFlag(k) = get(handles.MethodBoxes(k),'Value');
end

UpdateErrorPlot(handles)
UpdateFunctionPlot(handles)
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% based on new error data update the convergence plot
function UpdateErrorPlot(handles)
% UpdateErrorPlot  Redraw the error-versus-refinement plot.
%
% For each method this draws:
%   * the error curve handles.Error(t,:) against handles.nnint,
%   * a large dot at the current slider level,
%   * when "Display Convergence Triangles" is checked, a triangle over the
%     last two refinement levels, labelled with the observed rate
%       p = log(e2/e1) / log(dt2/dt1),  with dt = 2^-n.
onf = {'off','on'};

hErrorLine             = handles.hErrorLines;
hErrorRateTriangle     = handles.hErrorRateTriangles;
hErrorRateTriangleText = handles.hErrorRateTrianglesText;
hErrorCurrentn         = handles.hErrorCurrentn;
n                      = get(handles.NControlSlider,'Value');
RTF                    = get(handles.RateTriangle,'Value');
nnint                  = handles.nnint;
% index of the current refinement level (empty if n is not in nnint)
nnint_index = find(nnint==n);

% the rate triangle needs at least two refinement levels; its x-data is
% the same for every method
hasTriangle = numel(nnint) >= 2;
if hasTriangle
    xx   = nnint(end-1:end);
    xtri = [xx(1) xx(1) xx(2) xx(1)];
end

for t = 1:length(handles.Methods)
    % named 'err' so MATLAB's error function is not shadowed
    err = handles.Error(t,:);
    VF  = handles.VisibleFlag(t);

    % error line
    set(hErrorLine(t),'xdata',nnint,'ydata',err);
    set(hErrorLine(t),'Visible',onf{VF+1})

    % dot for the current refinement level
    set(hErrorCurrentn(t),'xdata',nnint(nnint_index),'ydata',err(nnint_index));
    set(hErrorCurrentn(t),'Visible',onf{VF+1})

    if ~hasTriangle
        set(hErrorRateTriangle(t),'Visible','off')
        set(hErrorRateTriangleText(t),'Visible','off')
        continue
    end

    % triangle drawn slightly below the last two error values
    yy   = err(end-1:end);
    ytri = 0.9*[yy(1) yy(2) yy(2) yy(1)];
    set(hErrorRateTriangle(t),'xdata',xtri,'ydata',ytri);
    set(hErrorRateTriangle(t),'Visible',onf{(VF&&RTF)+1})

    % observed order between the last two levels (dt = 2^-xx)
    rate = log(yy(2)/yy(1))/log((2^xx(1))/(2^xx(2)));
    set(hErrorRateTriangleText(t),'Position',[1/0.98*xx(1),0.7*sqrt(yy(2)*yy(1))]);
    set(hErrorRateTriangleText(t),'String',sprintf('%.2f',rate));
    set(hErrorRateTriangleText(t),'Visible',onf{(VF&&RTF)+1})
end

%set(handles.ErrorAxes,'YLim',handles.ylim(get(handles.FunctionSelector,'Value'),:));

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% based on request update the solution plot
function UpdateFunctionPlot(handles)
% UpdateFunctionPlot  Redraw the exact solution and each method's numerical
% solution for the current slider level n (step size dt = 2^-n), then
% refresh the stability-plot visibility.
onf = {'off','on'};

hMethodLine = handles.hMethodLines;
hFunction   = handles.hFunction;
n           = get(handles.NControlSlider,'Value');
funindex    = get(handles.FunctionSelector,'Value');
g           = handles.fun{funindex};
g_u         = handles.dfun{funindex};
g_exact     = handles.fun_exact{funindex};
tf          = handles.TimeFinal;
u0          = handles.u0;

% exact solution on a fine grid
te = linspace(0,tf,1000);
set(hFunction,'xdata',te,'ydata',g_exact(te));

% numerical solutions on the uniform grid 0:dt:tf
dt = (1/2)^n;
tt = 0:dt:tf;
for t = 1:length(handles.Methods)
    switch t
        case 1
            [~,uu] = eulerforward(g,tt,u0);
        case 2
            [~,uu] = rungekutta2(g,tt,u0);
        case 3
            [~,uu] = rungekutta4(g,tt,u0);
        case 4
            [~,uu] = glirk2(g,g_u,tt,u0);
        otherwise
            error('RungeKuttaTimeIntegrationExplorer:unknownMethod',...
                'No integrator is associated with method %d (%s).',...
                t,handles.Methods{t});
    end

    set(hMethodLine(t),'xdata',tt,'ydata',uu);
    set(hMethodLine(t),'Visible',onf{handles.VisibleFlag(t)+1})
end
UpdateStabilityPlot(handles)
end

function handles = AddStabilityPlot(handles)
% AddStabilityPlot  Draw each method's linear stability picture (shaded
% patch plus the |gamma| = 1 boundary line) on handles.StabilityAxes.
% The graphics handles are stored in handles.StabilityPatches and
% handles.StabilityLines so their visibility can be toggled later.
StabilityAxes = handles.StabilityAxes;
xlabel(StabilityAxes,'Re(\lambda\Deltat)');
ylabel(StabilityAxes,'Im(\lambda\Deltat)');

Color  = handles.Color;
Scolor = handles.StabilityColor;
xdom   = handles.StabilityXdom;
ydom   = handles.StabilityYdom;
axis(StabilityAxes,[xdom ydom])

% drawstabcontour integrator code for each entry of handles.Methods:
% forward Euler, RK2, RK4, two-stage Gauss-Legendre IRK
types = {'ef','rk2','rk4','irk2'};

for t = 1:length(handles.Methods)
    if t > numel(types)
        error('RungeKuttaTimeIntegrationExplorer:unknownMethod',...
            'No stability region is associated with method %d (%s).',...
            t,handles.Methods{t});
    end
    [hp,hl] = drawstabcontour(types{t},xdom,ydom,StabilityAxes,Color(t,:));
    % translucent face in the method's pale stability color, no edge
    set(hp,'FaceAlpha',0.2,'FaceColor',Scolor(t,:),'LineStyle','none');
    set(hl,'color',Color(t,:),'linewidth',3);

    handles.StabilityPatches(t) = hp;
    handles.StabilityLines(t)   = hl;
end

end

function UpdateStabilityPlot(handles)
% Show or hide every method's stability patch and boundary line according
% to handles.VisibleFlag (1 = shown, 0 = hidden).
states = {'off','on'};
for k = 1:length(handles.Methods)
    state = states{handles.VisibleFlag(k)+1};
    set(handles.StabilityPatches(k),'Visible',state)
    set(handles.StabilityLines(k),  'Visible',state)
end

end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Compute the error associated with current funindex for all methods
function handles = ComputeError(handles)
% ComputeError  Integrate the selected problem with every method at each
% refinement level n in handles.nnint (step size dt = 2^-n) and store the
% absolute error at the last time step in handles.Error.
%
% handles.Error is 4-by-numel(handles.nnint); its rows are, in order,
% forward Euler, RK2, RK4 and the two-stage Gauss-Legendre IRK.
dts      = 0.5.^(handles.nnint);
tf       = handles.TimeFinal;
u0       = handles.u0;
funindex = get(handles.FunctionSelector,'Value');
g        = handles.fun{funindex};
g_u      = handles.dfun{funindex};
g_exact  = handles.fun_exact{funindex};

nlev = numel(dts);
err  = zeros(4,nlev);   % 'err' avoids shadowing MATLAB's error function
for j = 1:nlev
  tt = 0:dts(j):tf;
  % compare against the exact solution at the time the integrators actually
  % reach; this equals tf whenever tf is a multiple of dt
  ue = g_exact(tt(end));

  [~,uu] = eulerforward(g,tt,u0);
  err(1,j) = abs(ue - uu(end));
  [~,uu] = rungekutta2(g,tt,u0);
  err(2,j) = abs(ue - uu(end));
  [~,uu] = rungekutta4(g,tt,u0);
  err(3,j) = abs(ue - uu(end));
  [~,uu] = glirk2(g,g_u,tt,u0);
  err(4,j) = abs(ue - uu(end));
end
handles.Error = err;

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Finite-difference approximation of df/dx at x with step h
function dfdx = FiniteDifferenceDerivative(fun,x,h,type)
% type selects the formula, case-insensitively: 'forward', 'backward' or
% 'centered'; any other value raises an error. (Not called from the
% visible code.)
kind = lower(type);
if strcmp(kind,'forward')
    dfdx = 1/h*(fun(x+h)-fun(x));
elseif strcmp(kind,'backward')
    dfdx = 1/h*(fun(x)-fun(x-h));
elseif strcmp(kind,'centered')
    dfdx = 1/(2*h)*(fun(x+h)-fun(x-h));
else
    error('unknown differentiation type');
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Sample points of a finite-difference stencil
function xp = GetxPnts(x,h,type)
% Returns the abscissae used by the chosen difference formula, as a
% horizontal concatenation: centered [x-h, x, x+h], backward [x-h, x],
% forward [x, x+h]. type is case-insensitive; others raise an error.
switch lower(type)
    case 'centered'
        xp = [x-h, x, x+h];
    case 'backward'
        xp = [x-h, x];
    case 'forward'
        xp = [x, x+h];
    otherwise
        error('unknown differentiation type');
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% y-values of the straight line through (x0,y0) with slope dfdx
function ydom = GetDerivativeLineYDOM(xdom,x0,y0,dfdx)
% Evaluates y = dfdx.*xdom + (y0 - dfdx*x0) at the points in xdom.
intercept = y0 - dfdx*x0;
ydom = dfdx.*xdom + intercept;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% forward euler
function [t,u] = eulerforward(g,t,u0)
% Explicit (forward) Euler on the time grid t:
%   u(:,k+1) = u(:,k) + (t(k+1)-t(k)) * g(t(k), u(:,k)).
% u has length(u0) rows; column k is the solution at t(k). t is returned
% unchanged.
u = zeros(length(u0),length(t));
u(:,1) = u0;
for k = 1:length(t)-1
  uk = u(:,k);
  u(:,k+1) = uk + (t(k+1) - t(k))*g(t(k),uk);
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% runge-kutta 2
function [t,u] = rungekutta2(g,t,u0)
% Explicit midpoint Runge-Kutta 2 on the time grid t: take an Euler half
% step to the interval midpoint, then a full step using the slope there.
% u has length(u0) rows; column k is the solution at t(k).
u = zeros(length(u0),length(t));
u(:,1) = u0;
for k = 1:length(t)-1
  h    = t(k+1) - t(k);
  uk   = u(:,k);
  kmid = g(t(:,k)+0.5*h, uk + 0.5*h*g(t(:,k),uk));
  u(:,k+1) = uk + h*kmid;
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% runge-kutta 4
function [t,u] = rungekutta4(g,t,u0)
% Classical four-stage Runge-Kutta method on the time grid t, with
% weights 1/6, 1/3, 1/3, 1/6 on the four slopes.
% u has length(u0) rows; column k is the solution at t(k).
u = zeros(length(u0),length(t));
u(:,1) = u0;
for k = 1:length(t)-1
  h  = t(k+1) - t(k);
  tk = t(:,k);
  uk = u(:,k);
  s1 = g(tk,       uk);
  s2 = g(tk+0.5*h, uk + 0.5*h*s1);
  s3 = g(tk+0.5*h, uk + 0.5*h*s2);
  s4 = g(tk+h,     uk + h*s3);
  u(:,k+1) = uk + h/6*(s1 + 2*s2 + 2*s3 + s4);
end
end





%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% two-stage Gauss-Legendre implicit runge-kutta
function [t,u] = glirk2(g,g_u,t,u0)
% Two-stage Gauss-Legendre implicit Runge-Kutta on the time grid t.
% At every step, the two stage values Y1, Y2 (stacked as [Y1; Y2]) solve
%   Yi = u_n + h*sum_j A(i,j) g(t_n + c(j) h, Yj),
% which is done with newton() starting from [u_n; u_n]. g_u(t,u) must be
% the Jacobian dg/du (m-by-m for m = length(u0)).
% u has m rows; column k is the solution at t(k).
sq3 = sqrt(3)/6;
A = [1/4, 1/4-sq3; 1/4+sq3, 1/4];
b = [1/2; 1/2];
c = [1/2-sq3, 1/2+sq3];

m  = length(u0);
Im = eye(m);
s1 = 1:m;        % indices of stage 1 in the stacked unknown
s2 = m+1:2*m;    % indices of stage 2
u = zeros(m,length(t));
u(:,1) = u0;
for k = 1:length(t)-1
  h  = t(k+1) - t(k);
  un = u(:,k);
  tn = t(:,k);
  ta = tn+c(1)*h;
  tb = tn+c(2)*h;
  res = @(x) [x(s1) - un - A(1,1)*h*g(ta,x(s1)) - A(1,2)*h*g(tb,x(s2));
              x(s2) - un - A(2,1)*h*g(ta,x(s1)) - A(2,2)*h*g(tb,x(s2))];
  jac = @(x) [Im - A(1,1)*h*g_u(ta,x(s1)), - A(1,2)*h*g_u(tb,x(s2)) ;
               - A(2,1)*h*g_u(ta,x(s1)), Im - A(2,2)*h*g_u(tb,x(s2))];
  stages = newton(res,jac,[un;un]);
  u(:,k+1) = un + h*(b(1)*g(ta,stages(s1)) + b(2)*g(tb,stages(s2)));
end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% newton solver
function u = newton(r,r_u,u0,maxiter,tol)
% NEWTON  Solve r(u) = 0 by Newton's method.
%
%   u = newton(r,r_u,u0) starts from u0 and iterates
%   u <- u - r_u(u) \ r(u), where r_u(u) is the Jacobian of r, until
%   norm(r(u)) < tol.
%   u = newton(r,r_u,u0,maxiter,tol) overrides the defaults maxiter = 20
%   and tol = 1e-10 (pass [] to keep a default).
%
%   Raises an error if the residual is not below tol (including a NaN
%   residual) after maxiter updates.
if nargin < 4 || isempty(maxiter), maxiter = 20; end
if nargin < 5 || isempty(tol),     tol = 1e-10;  end

u   = u0;
res = r(u);
for iter = 1:maxiter
  if norm(res) < tol
    return;
  end
  u   = u - r_u(u)\res;
  res = r(u);   % evaluated once per iteration and reused in the next check
end
% the last update is checked too; written as ~(x < tol) so NaN fails
if ~(norm(res) < tol)
  error('newton did not converge');
end
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% compute and plot contours
function [hpatch, hline] = drawstabcontour(type,xdom,ydom,axes_handle,color)
% DRAWSTABCONTOUR  Plot the linear stability boundary of an integrator.
%
%   [hpatch,hline] = drawstabcontour(type,xdom,ydom,axes_handle,color)
%   evaluates the amplification factor gamma(z) on a 301x301 grid of
%   z = x + 1i*y over xdom x ydom. It draws the |gamma| = 1 contour as a
%   line, plus a patch whose outline is the domain box followed by that
%   contour. For 'cn'/'cranknicolson' and 'irk2' the patch is instead the
%   right half of the domain (Re z >= 0).
%
%   type:        'eb'/'eulerbackward' (default), 'ef'/'eulerforward',
%                'rk2', 'rk4', 'irk2', 'ab2', 'ab3', 'am2',
%                'cn'/'cranknicolson', 'bdf2', 'bdf3'
%   xdom, ydom:  real / imaginary ranges, default [-3,3]
%   axes_handle: target axes (default gca); it becomes the current axes
%   color:       RGB color of patch and line, default [0 0 1]
if (nargin < 1), type = 'eb'; end
if (nargin < 2), xdom = [-3,3]; end
if (nargin < 3), ydom = [-3,3]; end
if (nargin < 4), axes_handle = gca; end
if (nargin < 5), color = [0 0 1]; end

x = linspace(xdom(1),xdom(2),301);
y = linspace(ydom(1),ydom(2),301);
[xx,yy] = meshgrid(x,y);
zz = xx + 1i*yy;

switch type
 case {'eb','eulerbackward'}
  gamma = eulerbackamp(zz);
 case {'ef','eulerforward'}
  gamma = eulerforwardamp(zz);
 case 'rk2'
  gamma = rk2amp(zz);
 case 'rk4'
  gamma = rk4amp(zz);
 case 'irk2'
  gamma = irk2amp(zz);
 case 'ab2'
  gamma = ab2amp(zz);
 case 'ab3'
  gamma = ab3amp(zz);
 case 'am2'
  gamma = am2amp(zz);
 case {'cn','cranknicolson'}
  gamma = cnamp(zz);
 case 'bdf2'
  gamma = bdf2amp(zz);
 case 'bdf3'
  gamma = bdf3amp(zz);
 otherwise
  error('unknown integrator type')
end

C = contourc(x,y,abs(gamma),[1,1]);

% contourc may return several segments, each preceded by a header column
% [level; count]; strip the headers instead of plotting them as vertices
[xseg,yseg] = parsecontourmatrix(C);

% for the line, separate segments with NaN so they are not joined; with a
% single segment this is exactly C(1,2:end) / C(2,2:end)
xbnd = joinsegments(xseg);
ybnd = joinsegments(yseg);

xp = [xdom(1) xdom(2) xdom(2) xdom(1) xdom(1) xseg{:}];
yp = [ydom(1) ydom(1) ydom(2) ydom(2) ydom(1) yseg{:}];

% need to handle special case of cranknicolson and irk2
switch type
    case {'cn', 'cranknicolson', 'irk2'}
        xp = [0.0     xdom(2) xdom(2) 0.0     0.0    ];
        yp = [ydom(1) ydom(1) ydom(2) ydom(2) ydom(1)];
    otherwise
        % don't need to do anything
end
axes(axes_handle);
hpatch = patch(xp,yp,color,'FaceAlpha',.2,'LineStyle','none');
hline = line(xbnd,ybnd);
set(hline,'color',color,'linewidth',3);

end

function [xs,ys] = parsecontourmatrix(C)
% Split a contourc matrix into per-segment x / y row vectors (cell arrays).
% Each segment is a header column [level; count] followed by count vertices.
xs = {};
ys = {};
k = 1;
while k <= size(C,2)
    cnt  = C(2,k);
    cols = k+1:k+cnt;
    xs{end+1} = C(1,cols); %#ok<AGROW>
    ys{end+1} = C(2,cols); %#ok<AGROW>
    k = k + cnt + 1;
end
end

function v = joinsegments(segs)
% Concatenate row-vector segments, inserting NaN between consecutive ones.
if isempty(segs)
    v = zeros(1,0);
    return
end
v = segs{1};
for k = 2:numel(segs)
    v = [v, NaN, segs{k}]; %#ok<AGROW>
end
end


function gamma = eulerbackamp(lamt)
% Backward Euler amplification factor, elementwise: gamma = 1/(1 - lamt).
denom = 1 - lamt;
gamma = 1./denom;
end

function gamma = eulerforwardamp(lamt)
% Forward Euler amplification factor, elementwise: gamma = 1 + lamt.
gamma = lamt + 1;
end

function gamma = rk2amp(lamt)
% Amplification factor of the midpoint RK2 method, elementwise:
% gamma = 1 + lamt*(1 + lamt/2), i.e. one step of rungekutta2 applied to
% du/dt = lambda*u from u = 1, with lamt = lambda*dt.
gamma = 1 + lamt.*(1 + 0.5.*lamt);
end

function gamma = rk4amp(lamt)
% Amplification factor of classical RK4, elementwise in lamt = lambda*dt:
% the stage values of one rungekutta4 step on du/dt = lambda*u from u = 1,
% combined with weights 1/6, 1/3, 1/3, 1/6.
s1 = 1;
s2 = 1 + 0.5.*lamt.*s1;
s3 = 1 + 0.5.*lamt.*s2;
s4 = 1 + lamt.*s3;
weighted = 1/6*s1 + 1/3*s2 + 1/3*s3 + 1/6*s4;
gamma = 1 + lamt.*weighted;
end

function gamma = irk2amp(lamt)
% Amplification factor of the two-stage Gauss-Legendre IRK (same tableau
% as glirk2), elementwise in z = lamt: the stage system (I - z*A) v = [1;1]
% is solved with the explicit 2x2 inverse, then gamma = 1 + z*(b.'*v).
a11 = 1/4;
a12 = 1/4-sqrt(3)/6;
a21 = 1/4+sqrt(3)/6;
a22 = 1/4;
trA  = a11 + a22;
detA = a11*a22 - a21*a12;

d  = 1-lamt.*(trA)+lamt.^2.*(detA);     % det(I - z*A)
v1 = 1./d.*(1-lamt*a22 + lamt*a12);
v2 = 1./d.*(lamt*a21 + 1 - lamt*a11);
gamma = 1 + lamt.*(0.5*v1 + 0.5*v2);
end

function gamma = cnamp(lamt)
% Crank-Nicolson amplification factor, elementwise:
% gamma = (1 + lamt/2) / (1 - lamt/2).
half = 0.5*lamt;
gamma = (1 + half)./(1 - half);
end

function gamma = ab2amp(lamt)
% Adams-Bashforth 2 amplification: the largest root modulus of
%   r^2 - (1 + 1.5*z) r + 0.5*z = 0,   z = lamt (elementwise),
% from the quadratic formula.
a = 1;
b = -(1+1.5*lamt);
c = 0.5*lamt;
rootdisc = sqrt(b.^2 - 4.*a.*c);
scale = 1/(2*a);
gamma = max(abs(scale.*(-b+rootdisc)), abs(scale.*(-b-rootdisc)));
end

function gamma = ab3amp(lamt)
% Adams-Bashforth 3 amplification: for each z = lamt(row,col), the largest
% root modulus of the cubic with coefficients [1; -1; 0; 0] - z*[0; 23/12;
% -4/3; 5/12] (highest power first).
a = [1; -1; 0; 0];
b = [0; 23/12; -4/3; 5/12];
gamma = zeros(size(lamt));
for col = 1:size(lamt,2)
    for row = 1:size(lamt,1)
        gamma(row,col) = max(abs(roots(a - b*lamt(row,col))));
    end
end
end


function gamma = am2amp(lamt)
% Amplification of the Adams-Moulton scheme with weights 5/12, 8/12, -1/12:
% the largest root modulus of
%   (1 - 5/12 z) r^2 - (1 + 2/3 z) r + z/12 = 0,   z = lamt (elementwise).
qa = 1-5/12*lamt;
qb = -(1+2/3*lamt);
qc = 1/12*lamt;
sdisc = sqrt(qb.^2 - 4.*qa.*qc);
den = 2.*qa;
gamma = max(abs(1./den.*(-qb+sdisc)), abs(1./den.*(-qb-sdisc)));
end

function gamma = bdf2amp(lamt)
% BDF2 amplification: the largest root modulus of
%   (1 - 2/3 z) r^2 - 4/3 r + 1/3 = 0,   z = lamt (elementwise).
qa = 1-2/3*lamt;
qb = -4/3;
qc = 1/3;
sdisc = sqrt(qb.^2 - 4.*qa.*qc);
den = 2.*qa;
gamma = max(abs(1./den.*(-qb+sdisc)), abs(1./den.*(-qb-sdisc)));
end

function gamma = bdf3amp(lamt)
% BDF3 amplification: for each z = lamt(row,col), the largest root modulus
% of (1 - 6/11 z) r^3 - 18/11 r^2 + 9/11 r - 2/11 = 0.
lead = 1-6/11*lamt;
gamma = zeros(size(lamt));
for col = 1:size(lamt,2)
  for row = 1:size(lamt,1)
    gamma(row,col) = max(abs(roots([lead(row,col), -18/11, 9/11, -2/11])));
  end
end
end




Contact us