% mexme - write MEX files in no time
%
% Published 30 Apr 2011.
%
% Writes fully valid MEX .cpp files, including the mexFunction
% boilerplate, from a numeric C/C++ snippet.
%
% Usage: mexme(cstring,inputargs,outputargs,opts)
function [cstring] = mexme(cstring,inputargs,outputargs,opts)
    %cstring = mexme(cstring,inputargs,outputargs,opts)
    %
    %A function that wraps a string of C code into a full-fledged mex file.
    %Takes care of writing the mexFunction boilerplate code for you.
    %Read TestMexMe.m for more details.
    %
    %Inputs:
    %  cstring    - either a C/C++ snippet that becomes the body of
    %               mexFunction, or the name of a C file that is #included
    %               there instead. A single line that does not end in ';'
    %               or '}' (trailing whitespace ignored) is taken to be a
    %               file name; anything else is treated as a snippet.
    %  inputargs  - array of input argument descriptions (e.g. created with
    %               InputNum), one per prhs[] entry, in order. Default: none.
    %  outputargs - array of output argument descriptions (e.g. created with
    %               OutputNum), one per plhs[] entry, in order. Default: none.
    %  opts       - options struct, see below. Default: all defaults.
    %
    %Output:
    %  cstring    - the full text of the generated C file.
    %
    %Errors:
    %  mexme:unsupportedType - an argument's type is not one of the Matlab
    %                          classes in the type table (uint*/int*,
    %                          single, double)
    %  mexme:unsupportedArg  - an argument is neither scalar nor full (e.g.
    %                          sparse); only full arguments are supported
    %
    %Quick example:
    %
    %  cstring = ...
    %  ['for(mwSize i = 0; i < x_numel; i++) { ' char(10) ...
    %   '    y[i] = x[i]*x[i]; ' char(10) ...
    %   '}'];
    %
    %  inputarg = InputNum('x');
    %  outputarg = OutputNum('y','x_m, x_n');
    %  cfile = mexme(cstring,inputarg,outputarg)
    %
    %Generates C code along these lines (abridged: the extra #includes and
    %the argument-checking code emitted when opts.iocheck is true are
    %omitted):
    % /* C file autogenerated by mexme.m */
    % #include <mex.h>
    % #include <math.h>
    % #include <matrix.h>
    %
    % void mexFunction( int nlhs, mxArray *plhs[],
    %                   int nrhs, const mxArray *prhs[] )
    % {
    %
    % /*Input output boilerplate*/
    %     const mxArray *x_ptr = prhs[0];
    %     const mwSize   x_m = mxGetM(x_ptr);
    %     const mwSize   x_n = mxGetN(x_ptr);
    %     const mwSize   x_length = x_m == 1 ? x_n : x_m;
    %     const mwSize   x_numel = mxGetNumberOfElements(x_ptr);
    %     const mwSize   x_ndims = mxGetNumberOfDimensions(x_ptr);
    %     const mwSize  *x_size = mxGetDimensions(x_ptr);
    %     const double  *x = (double *) mxGetData(x_ptr);
    %
    %     mwSize y_dims[] = {x_m, x_n};
    %     plhs[0] = mxCreateNumericArray(2,y_dims,mxDOUBLE_CLASS,mxREAL);
    %     mxArray **y_ptr = &plhs[0];
    %     double   *y = (double *) mxGetData(*y_ptr);
    %
    % /*Actual function*/
    %     for(mwSize i = 0; i < x_numel; i++) {
    %         y[i] = x[i]*x[i];
    %     }
    %
    % }
    %
    %opts:
    %  opts.iocheck = true (default) | false
    %      if true, include code to verify that arguments fed to mex file
    %      are consistent with definitions (number of arguments, type,
    %      complexity, scalar-ness and any per-argument extra check)
    %  opts.extraincludes = '' (default) | valid C string
    %      extra code inserted after the #include lines but before
    %      mexFunction (extra #includes, helper function definitions, ...)
    %
    %See also InputNum, OutputNum, readfile, writefile

    %Defines formatting tabs (4 spaces)
    TAB = '    ';
    NL = char(10);

    %type translator from Matlab to C
    tt = containers.Map;
    tt('uint64') = 'UINT64_T';
    tt('uint32') = 'UINT32_T';
    tt('uint16') = 'UINT16_T';
    tt('uint8')  = 'UINT8_T';

    tt('int64') = 'INT64_T';
    tt('int32') = 'INT32_T';
    tt('int16') = 'INT16_T';
    tt('int8')  = 'INT8_T';

    tt('single') = 'REAL32_T';
    tt('double') = 'double'; %A double is a double is a double

    %Defaults for omitted arguments
    if nargin < 2
        inputargs = [];
    end
    if nargin < 3
        outputargs = [];
    end
    if nargin < 4 || isempty(opts)
        opts = struct();
    end
    if ~isfield(opts,'iocheck')
        opts.iocheck = true;
    end
    if ~isfield(opts,'extraincludes')
        opts.extraincludes = '';
    end

    templates = defineTemplates();

    %Parse input args: each one becomes a block of declarations that
    %unpacks prhs[ii-1] (C indices are 0-based, messages are 1-based).
    inputstr = '';
    for ii = 1:length(inputargs)
        arg = inputargs(ii);

        if opts.iocheck
            typecheck = parseTemplate(templates.typecheck,'ARGNUM',ii,'TYPE',arg.type,'MATTYPE',mattype(arg.type),'NAME',arg.name);
            cplxcheck = parseTemplate(templates.cplxcheck,'ARGNUM',ii,'TYPE',arg.type,'NAME',arg.name);
            scalarcheck = parseTemplate(templates.scalarcheck,'ARGNUM',ii,'TYPE',arg.type,'NAME',arg.name);
        else
            typecheck = '';
            cplxcheck = '';
            scalarcheck = '';
        end

        %Every supported kind must select a template; an unsupported kind
        %must fail loudly rather than leave the generated code unset or
        %stale from the previous argument.
        if arg.isscalar
            tmpl = templates.iscalar;
        elseif arg.isfull && arg.isreal
            tmpl = templates.irealfull;
        elseif arg.isfull
            tmpl = templates.icomplexfull;
        else
            error('mexme:unsupportedArg', ...
                  'Input argument %s (#%d) is neither scalar nor full; only full arguments are supported', ...
                  arg.name, ii);
        end
        a = parseTemplate(tmpl,'NAME',arg.name,'ARGNUM',ii-1,'TYPE',ctypeOf(tt,arg.type,arg.name), ...
                          'TYPECHECK',typecheck,'CPLXCHECK',cplxcheck,'SCALARCHECK',scalarcheck);

        if opts.iocheck && ~isempty(arg.extracheck)
            b = parseTemplate(templates.extracheck,'ARGNUM',ii,'NAME',arg.name,'EXTRA_CHECK',arg.extracheck,'ESCAPED_EXTRA_CHECK',escape(arg.extracheck));
            a = [a NL b];
        end

        inputstr = [inputstr a NL];
    end

    %Parse output args: each one allocates plhs[ii-1] and exposes a data
    %pointer for the snippet to fill.
    outputstr = '';
    for ii = 1:length(outputargs)
        arg = outputargs(ii);
        if ~arg.isfull
            error('mexme:unsupportedArg', ...
                  'Output argument %s (#%d) is not full; only full arguments are supported', ...
                  arg.name, ii);
        end
        if arg.isreal
            tmpl = templates.orealfull;
        else
            tmpl = templates.ocomplexfull;
        end
        a = parseTemplate(tmpl,'NAME',arg.name,'ARGNUM',ii-1,'TYPE',ctypeOf(tt,arg.type,arg.name), ...
                          'NDIMS',countDims(arg.dims),'DIMS',['{' arg.dims '}'],'MATTYPE',mattype(arg.type));

        outputstr = [outputstr a NL];
    end

    if opts.iocheck
        nargcheck = parseTemplate(templates.nargcheck,'NRHS',length(inputargs),'NLHS',length(outputargs));
    else
        nargcheck = '';
    end

    %A single line that does not end in ';' or '}' is the name of a C file
    %to #include; anything else (including an empty string) is a snippet.
    %Trailing whitespace is ignored for the ';'/'}' test.
    hasNewline = any(cstring == NL);
    trimmed = deblank(cstring);
    isCFile = ~hasNewline && ~isempty(trimmed) && ~any(trimmed(end) == ';}');
    if isCFile
        snippet = sprintf('#include "%s"\n',cstring);
    else
        snippet = tabulate(cstring,TAB);
    end

    cstring = parseTemplate(templates.main,'EXTRAINCLUDES',opts.extraincludes,...
                                           'INPUT',  tabulate(inputstr,TAB),...
                                           'OUTPUT', tabulate(outputstr,TAB),...
                                           'SNIPPET',snippet,...
                                           'NARGCHECK',tabulate(nargcheck,TAB));
end

function ctype = ctypeOf(tt, type, name)
    %ctype = ctypeOf(tt,type,name)
    %Look up the C type for Matlab class TYPE in the translation table TT.
    %Raises mexme:unsupportedType (mentioning argument NAME and the known
    %classes) when TYPE is not in the table, instead of the generic
    %containers.Map "key not present" error.
    if ~isKey(tt, type)
        known = keys(tt);
        error('mexme:unsupportedType', ...
              'Argument %s has unsupported type ''%s''; supported types are: %s', ...
              name, type, sprintf('%s ', known{:}));
    end
    ctype = tt(type);
end

function b = escape(arg)
    %b = escape(arg)
    %Escape ARG so it can be embedded inside a double-quoted C string
    %literal (used for the text of a failed extra check in the generated
    %mexErrMsgTxt call).
    %Backslashes are doubled first so that the escapes added afterwards are
    %not escaped again. Then double quotes, carriage returns and newlines
    %are turned into C escape sequences. A raw newline would end the
    %literal and break compilation of the generated file.
    b = strrep(arg, '\', '\\');
    b = strrep(b, '"', '\"');
    b = strrep(b, char(13), '\r');
    b = strrep(b, char(10), '\n');
end

function nums = countDims(dims)
    %nums = countDims(dims)
    %Number of dimensions described by DIMS, a comma-separated list of C
    %expressions such as 'x_m, x_n'. DIMS becomes the initializer {DIMS} of
    %an mwSize array, and NUMS is passed as the ndims argument of
    %mxCreateNumericArray.
    %Only commas outside (), [] and {} separate entries, so an entry like
    %'MAX(a,b)' counts once. Counting every comma would make ndims larger
    %than the initializer array.
    %Errors when fewer than two dimensions are given.
    opening = dims == '(' | dims == '[' | dims == '{';
    closing = dims == ')' | dims == ']' | dims == '}';
    depth = cumsum(double(opening) - double(closing)); % nesting level at each char
    nums = nnz(dims == ',' & depth == 0) + 1;
    if nums < 2
        error('An output argument must have two or more dimensions');
    end
end

function type = mattype(type)
    %type = mattype(type)
    %Turn a Matlab class name (e.g. 'double') into the name of the matching
    %mxClassID constant used in the generated C code (e.g. 'mxDOUBLE_CLASS').
    prefix = 'mx';
    suffix = '_CLASS';
    type = [prefix, upper(type), suffix];
end

function str = tabulate(str,tab)
    %str = tabulate(str,tab)
    %Indent every line of STR by prefixing it with TAB. The first line gets
    %TAB in front of it, and every newline is followed by TAB, so a
    %trailing newline also leaves a final line containing only TAB.
    newline_and_indent = [char(10) tab];
    indented_rest = strrep(str, char(10), newline_and_indent);
    str = [tab, indented_rest];
end
   
function [templates] = defineTemplates()
%templates = defineTemplates()
%Returns a struct whose fields are the C code templates used by mexme:
%  main                     - skeleton of the whole generated file
%  nargcheck                - check on nlhs/nrhs (used when opts.iocheck)
%  typecheck, cplxcheck,
%  scalarcheck, extracheck  - per-input-argument checks (opts.iocheck)
%  iscalar, irealfull,
%  icomplexfull             - declarations unpacking an input from prhs[]
%  orealfull, ocomplexfull  - declarations creating an output in plhs[]
%Tokens of the form ^KEY^ (e.g. ^NAME^, ^ARGNUM^, ^TYPE^) are placeholders
%that mexme fills in through parseTemplate.
%Each field is assigned by verbatim, which presumably returns the text of
%the block comment that immediately follows the call (NOTE(review):
%confirm against verbatim.m). Those block contents are therefore runtime
%data, not documentation. Editing them changes the generated C code, and
%nothing should be inserted between a "= verbatim;" line and its block.
templates.main = verbatim;
%{
/* C file autogenerated by mexme.m */
#include <mex.h>
#include <math.h>
#include <matrix.h>
#include <stdlib.h>
#include <float.h>
#include <string.h>

#include "mexmetypecheck.cpp"

/* Your extra includes and function definitions here */
^EXTRAINCLUDES^

void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] )
{

/*Input output boilerplate*/
^NARGCHECK^

^INPUT^

^OUTPUT^


/*Actual function*/
^SNIPPET^

}


%}

templates.nargcheck = verbatim;
%{
if(nlhs != ^NLHS^ || nrhs != ^NRHS^)
    mexErrMsgTxt("Function must be called with ^NRHS^ arguments and has ^NLHS^ return values");
%}

templates.typecheck = verbatim;
%{
mexmetypecheck(^NAME^_ptr,^MATTYPE^,"Argument ^NAME^ (#^ARGNUM^) is expected to be of type ^TYPE^");
%}

templates.cplxcheck = verbatim;
%{
if(!mxIsComplex(^NAME^_ptr))
    mexErrMsgTxt("Argument ^NAME^ (#^ARGNUM^) must be complex");
%}

templates.scalarcheck = verbatim;
%{
if(mxGetNumberOfElements(^NAME^_ptr) != 1)
    mexErrMsgTxt("Argument ^NAME^ (#^ARGNUM^) must be scalar");
%}

templates.extracheck = verbatim;
%{
if(!(^EXTRA_CHECK^))
    mexErrMsgTxt("Argument ^NAME^ (#^ARGNUM^) did not pass test \"^ESCAPED_EXTRA_CHECK^\"");
%}

templates.iscalar = verbatim;
%{
const mxArray *^NAME^_ptr = prhs[^ARGNUM^];
^TYPECHECK^
^SCALARCHECK^
const ^TYPE^   ^NAME^ = (^TYPE^) mxGetScalar(^NAME^_ptr);
%}

templates.irealfull = verbatim;
%{
const mxArray *^NAME^_ptr = prhs[^ARGNUM^];
^TYPECHECK^
const mwSize   ^NAME^_m = mxGetM(^NAME^_ptr);
const mwSize   ^NAME^_n = mxGetN(^NAME^_ptr);
const mwSize   ^NAME^_length = ^NAME^_m == 1 ? ^NAME^_n : ^NAME^_m;
const mwSize   ^NAME^_numel = mxGetNumberOfElements(^NAME^_ptr);
const mwSize   ^NAME^_ndims = mxGetNumberOfDimensions(^NAME^_ptr);
const mwSize  *^NAME^_size = mxGetDimensions(^NAME^_ptr);
const ^TYPE^  *^NAME^ = (^TYPE^ *) mxGetData(^NAME^_ptr);
%}

templates.icomplexfull = verbatim;
%{
const mxArray *^NAME^_ptr = prhs[^ARGNUM^];
^TYPECHECK^
^CPLXCHECK^
const mwSize   ^NAME^_m = mxGetM(^NAME^_ptr);
const mwSize   ^NAME^_n = mxGetN(^NAME^_ptr);
const mwSize   ^NAME^_length = ^NAME^_m == 1 ? ^NAME^_n : ^NAME^_m;
const mwSize   ^NAME^_numel = mxGetNumberOfElements(^NAME^_ptr);
const mwSize   ^NAME^_ndims = mxGetNumberOfDimensions(^NAME^_ptr);
const mwSize  *^NAME^_size = mxGetDimensions(^NAME^_ptr);
const ^TYPE^  *^NAME^_r = (^TYPE^ *) mxGetData(^NAME^_ptr);
const ^TYPE^  *^NAME^_i = (^TYPE^ *) mxGetImagData(^NAME^_ptr);
%}

templates.orealfull = verbatim;
%{
mwSize ^NAME^_dims[] = ^DIMS^;
plhs[^ARGNUM^] = mxCreateNumericArray(^NDIMS^,^NAME^_dims,^MATTYPE^,mxREAL);
mxArray **^NAME^_ptr = &plhs[^ARGNUM^];
^TYPE^   *^NAME^ = (^TYPE^ *) mxGetData(*^NAME^_ptr);
%}

templates.ocomplexfull = verbatim;
%{
mwSize ^NAME^_dims[] = ^DIMS^;
plhs[^ARGNUM^] = mxCreateNumericArray(^NDIMS^,^NAME^_dims,^MATTYPE^,mxCOMPLEX);
mxArray **^NAME^_ptr = &plhs[^ARGNUM^];
^TYPE^   *^NAME^_r = (^TYPE^ *) mxGetData(*^NAME^_ptr);
^TYPE^   *^NAME^_i = (^TYPE^ *) mxGetImagData(*^NAME^_ptr);
%}
end

Contact us