
Define Custom F-Beta Score Metric Object

Note

This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.

In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.

How To Decide Which Metric Type To Use

If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.

To define a custom deep learning metric class, you can use the template in this example, which takes you through these steps:

  1. Name the metric — Give the metric a name so that you can use it in MATLAB®.

  2. Declare the metric properties — Specify the public and private properties of the metric.

  3. Create a constructor function — Specify how to construct the metric and set default values.

  4. Create an initialization function (optional) — Specify how to initialize variables and run validation checks.

  5. Create a reset function — Specify how to reset the metric properties between iterations.

  6. Create an update function — Specify how to update metric properties between iterations.

  7. Create an aggregation function — Specify how to aggregate the metric values across multiple instances of the metric object.

  8. Create an evaluation function — Specify how to calculate the metric value for each iteration.

This example shows how to create a custom F-beta score metric. This equation defines the metric:

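$$F_\beta = \frac{(1+\beta^2) \times \text{True Positive}}{\left((1+\beta^2) \times \text{True Positive}\right) + \left(\beta^2 \times \text{False Negative}\right) + \text{False Positive}}$$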
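To check the equation by hand, this minimal sketch computes the score from hypothetical counts (TP = 8, FP = 2, FN = 4, chosen only for illustration) for three beta values. It also computes the equivalent precision/recall form, (1+β²)·P·R/(β²·P + R), to show that the two forms agree.

```matlab
% Hypothetical counts, chosen only to illustrate the formula.
tp = 8;
fp = 2;
fn = 4;

betas = [0.5 1 2];
fCounts = zeros(size(betas));
fPR = zeros(size(betas));

% Precision and recall for the same counts.
precision = tp/(tp + fp);
recall = tp/(tp + fn);

for k = 1:numel(betas)
    betaSq = betas(k)^2;
    % F-beta score from the TP, FP, and FN counts.
    fCounts(k) = (1 + betaSq)*tp/((1 + betaSq)*tp + betaSq*fn + fp);
    % Equivalent form that uses precision and recall.
    fPR(k) = (1 + betaSq)*precision*recall/(betaSq*precision + recall);
end

disp(fCounts)
```

Here fCounts is approximately [0.7692 0.7273 0.6897]. Because precision (0.8) exceeds recall (about 0.667) for these counts, the smaller beta, which weights precision more heavily, gives the higher score.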

To see the completed metric class definition, see Completed Metric.

Tip

If you need the F1 metric, then you can use the built-in F-score metric. For more information, see Metrics.

Metric Template

Copy the metric template into a new file in MATLAB. This template gives the structure of a metric class definition. It outlines:

  • The properties block for public metric properties. This block must contain the Name property.

  • The properties block for private metric properties. This block is optional.

  • The metric constructor function.

  • The optional initialize function.

  • The required reset, update, aggregate, and evaluate functions.

classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties
            %
            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %           evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end

Metric Name

First, give the metric a name. In the first line of the class file, replace the existing name myMetric with fBetaMetric.

classdef fBetaMetric < deep.Metric
    ... 
end

Next, rename the myMetric constructor function (the first function in the methods section) so that it has the same name as the metric.

    methods
        function metric = fBetaMetric(args)
            ...
        end
    ...
    end

Save Metric

Save the metric class definition in a new file named fBetaMetric.m. The file name must match the metric name. To use the metric, save the file in the current folder or in a folder on the MATLAB path.

Declare Properties

Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.

  • properties — Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.

  • properties (Access = private) — Only members of the defining class can access the property.

Declare Public Properties

Declare public properties by listing them in the properties section. This section must contain the Name property. This metric also contains the Beta property. Include Beta as a public property so that you can access it outside the class.

    properties
        % (Required) Metric name.
        Name

        % Beta value for F-beta score metric.
        Beta
    end

Declare Private Properties

Declare private properties by listing them in the properties (Access = private) section. This metric requires three properties to evaluate the value: true positives (TPs), false positives (FPs), and false negatives (FNs). Only the functions within the metric class require access to these values.

    properties (Access = private)
        % Define true positives (TPs), false positives (FPs), and false
        % negatives (FNs).
        TruePositives
        FalsePositives
        FalseNegatives
    end

Create Constructor Function

Create the function that constructs the metric and initializes the metric properties. If the software requires any variables to evaluate the metric value, then these variables must be inputs to the constructor function.

The F-beta score metric constructor function accepts the Name, Beta, NetworkOutput, and Maximize arguments. All of these arguments are optional when you use the constructor to create a metric object. Specify an args input to the fBetaMetric function that corresponds to the optional name-value arguments. Add a comment to explain the syntax of the function.

        function metric = fBetaMetric(args)
            % metric = fBetaMetric creates an fBetaMetric metric object.

            % metric = fBetaMetric(Name=name,Beta=beta,NetworkOutput="out1",Maximize=true)
            % also specifies the optional Name, Beta, NetworkOutput, and
            % Maximize options. By default, the metric name is "FBeta" with a
            % beta value of 1. By default, NetworkOutput is [], which
            % corresponds to using all of the network outputs. By default,
            % Maximize is true because the optimal value occurs when the
            % F-beta score is maximized.

            ...            
        end

Next, set the default values for the metric properties. Parse the input arguments using an arguments block. Specify the default metric name as "FBeta", the default beta value as 1, the default network output as [], and the default Maximize value as true. The metric name appears in plots and verbose output.

        function metric = fBetaMetric(args)
            ...
         
            arguments
                args.Name = "FBeta"
                args.Beta = 1
                args.NetworkOutput = []
                args.Maximize = true
            end
            ...
        end

Set the properties of the metric.

        function metric = fBetaMetric(args)
            ...
       
            % Set the metric name and beta value.
            metric.Name = args.Name;
            metric.Beta = args.Beta;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the Maximize property.
            metric.Maximize = args.Maximize;
        end

View the completed constructor function. With this constructor function, the command fBetaMetric(Name="fbeta",Beta=0.5) creates an F-beta score metric object with the name "fbeta" and a beta value of 0.5.

        function metric = fBetaMetric(args)
            % metric = fBetaMetric creates an fBetaMetric metric object.

            % metric = fBetaMetric(Name=name,Beta=beta,NetworkOutput="out1",Maximize=true)
            % also specifies the optional Name, Beta, NetworkOutput, and
            % Maximize options. By default, the metric name is "FBeta" with a
            % beta value of 1. By default, NetworkOutput is [], which
            % corresponds to using all of the network outputs. By default,
            % Maximize is true because the optimal value occurs when the
            % F-beta score is maximized.

            arguments
                args.Name = "FBeta"
                args.Beta = 1
                args.NetworkOutput = []
                args.Maximize = true
            end

            % Set the metric name and beta value.
            metric.Name = args.Name;
            metric.Beta = args.Beta;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the Maximize property.
            metric.Maximize = args.Maximize;
        end

Create Initialization Function

Create the optional function that initializes variables and runs validation checks. For this example, the metric does not need the initialize function, so you can delete it. For an example of an initialize function, see Initialization Function.

Create Reset Function

Create the function that resets the metric properties. The software calls this function before each iteration. For the F-beta score metric, reset the TP, FP, and FN values to zero at the start of each iteration.

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TruePositives  = 0;
            metric.FalsePositives = 0;
            metric.FalseNegatives = 0;
        end

Create Update Function

Create the function that updates the metric properties that you use to compute the F-beta score value. The software calls this function for each training and validation mini-batch.

In the update function, define these steps:

  1. Find the maximum score for each observation. The maximum score corresponds to the predicted class for each observation.

  2. Find the TP, FP, and FN values.

  3. Add the batch TP, FP, and FN values to the running total number of TPs, FPs, and FNs.

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the per-class TP, FP, and FN counts for this batch by
            % summing over the observations (dimension 2 for "CB" data).
            batchTruePositives = sum(batchY & batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);
            batchFalseNegatives = sum(~batchY & batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TruePositives = metric.TruePositives + batchTruePositives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
            metric.FalseNegatives = metric.FalseNegatives + batchFalseNegatives;
        end
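To see what the update function computes, this sketch applies the same three steps to plain numeric arrays instead of formatted dlarray objects. It assumes the "CB" layout (classes down the rows, observations across the columns), which is why the sums run over dimension 2. The scores and targets are made-up values for illustration.

```matlab
% Hypothetical softmax scores: 3 classes (rows), 4 observations (columns).
scores = [0.7 0.1 0.2 0.3
          0.2 0.8 0.3 0.3
          0.1 0.1 0.5 0.4];

% Hypothetical one-hot targets: the true classes are 1, 2, 2, and 3.
targets = logical([1 0 0 0
                   0 1 1 0
                   0 0 0 1]);

% Step 1: mark the highest-scoring class of each observation with 1.
predicted = scores == max(scores,[],1);

% Step 2: count TPs, FPs, and FNs per class by summing over observations.
tp = sum(predicted & targets, 2);
fp = sum(predicted & ~targets, 2);
fn = sum(~predicted & targets, 2);

% Step 3: add the counts to running totals that start at their reset value 0.
totalTP = 0 + tp;
totalFP = 0 + fp;
totalFN = 0 + fn;
```

For this batch, totalTP is [1;1;1], totalFP is [0;0;1], and totalFN is [0;1;0]: the third observation belongs to class 2 but is predicted as class 3. Note that the comparison with max marks every tied class as predicted, so an observation with tied top scores can contribute more than one positive.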

Create Aggregation Function

Create the function that specifies how to combine the metric values and properties across multiple instances of the metric. For example, the aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training.

For this example, to combine the TP, FP, and FN values, add the values from each metric instance.

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.

            metric.TruePositives = metric.TruePositives + metric2.TruePositives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
            metric.FalseNegatives = metric.FalseNegatives + metric2.FalseNegatives;
        end
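Adding the raw counts, rather than averaging finished scores, is what makes the aggregation exact. This sketch uses made-up per-class counts for two hypothetical metric instances and a helper, fbeta, that mirrors the formula used in the evaluate function. It shows that pooling the counts and then evaluating once can give a very different result from averaging the two instances' scores.

```matlab
beta = 1;
betaSq = beta^2;

% Hypothetical helper: per-class F-beta score, averaged over the classes.
fbeta = @(tp,fp,fn) mean((1+betaSq).*tp ./ ((1+betaSq).*tp + betaSq.*fn + fp + eps));

% Hypothetical per-class counts collected by two metric instances.
tp1 = [9; 1];  fp1 = [1; 0];  fn1 = [0; 3];
tp2 = [1; 8];  fp2 = [3; 1];  fn2 = [1; 0];

% Aggregate as in the aggregate function: add the counts, then evaluate once.
fPooled = fbeta(tp1 + tp2, fp1 + fp2, fn1 + fn2);

% Naive alternative: evaluate each instance separately and average the scores.
fAveraged = mean([fbeta(tp1,fp1,fn1) fbeta(tp2,fp2,fn2)]);
```

With these counts, fPooled is about 0.809 while fAveraged is about 0.655, because the F-beta score is a ratio and does not average linearly across partitions of the data.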

Create Evaluation Function

Create the function that specifies how to compute the metric value in each iteration. The F-beta score is defined as:

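$$F_\beta = \frac{(1+\beta^2) \times \text{True Positive}}{\left((1+\beta^2) \times \text{True Positive}\right) + \left(\beta^2 \times \text{False Negative}\right) + \text{False Positive}}$$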

Implement this equation in the evaluate function.

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TP, FP, and FN values.
            tp = metric.TruePositives;
            fp = metric.FalsePositives;
            fn = metric.FalseNegatives;

            % F-beta score = (1+Beta^2)*TP/((1+Beta^2)*TP+(Beta^2)*FN+FP)
            betaSq = metric.Beta^2;
            betaSqAddOneTP = (1+betaSq).*tp;

            % Compute F-beta score.
            val = mean(betaSqAddOneTP./(betaSqAddOneTP+betaSq.*fn+fp+eps));
        end

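Because the denominator can be zero, for example for a class that has no true positives, false positives, or false negatives, add eps to the denominator to prevent the metric from returning NaN. The TP, FP, and FN properties hold one count per class, so the function computes the F-beta score for each class and returns the mean over the classes (a macro-averaged score).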
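This sketch uses made-up running totals for two classes, where the second class never appears in either the predictions or the targets, to show the effect of eps in the evaluate computation.

```matlab
beta = 0.5;
betaSq = beta^2;

% Hypothetical running totals. Class 2 has no TPs, FPs, or FNs.
tp = [5; 0];
fp = [1; 0];
fn = [2; 0];

betaSqAddOneTP = (1+betaSq).*tp;

% Without eps, class 2 evaluates to 0/0, which is NaN, so the mean is NaN.
valNoEps = mean(betaSqAddOneTP./(betaSqAddOneTP+betaSq.*fn+fp));

% With eps, class 2 evaluates to 0/eps, which is 0, so the mean is finite.
valEps = mean(betaSqAddOneTP./(betaSqAddOneTP+betaSq.*fn+fp+eps));
```

Here valNoEps is NaN and valEps is about 0.403. Keep in mind that the absent class still contributes a score of 0, so it lowers the averaged value compared with the score of class 1 alone (about 0.806).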

Completed Metric

View the completed metric class file.

Note

For more information about when the software calls each function in the class, see Function Call Order.

classdef fBetaMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Beta value for F-beta score metric.
        Beta
    end

    properties (Access = private)
        % Define true positives (TPs), false positives (FPs), and false
        % negatives (FNs).
        TruePositives
        FalsePositives
        FalseNegatives
    end

    methods
        function metric = fBetaMetric(args)
            % metric = fBetaMetric creates an fBetaMetric metric object.

            % metric = fBetaMetric(Name=name,Beta=beta,NetworkOutput="out1",Maximize=true)
            % also specifies the optional Name, Beta, NetworkOutput, and
            % Maximize options. By default, the metric name is "FBeta" with a
            % beta value of 1. By default, NetworkOutput is [], which
            % corresponds to using all of the network outputs. By default,
            % Maximize is true because the optimal value occurs when the
            % F-beta score is maximized.

            arguments
                args.Name = "FBeta"
                args.Beta = 1
                args.NetworkOutput = []
                args.Maximize = true
            end

            % Set the metric name and beta value.
            metric.Name = args.Name;
            metric.Beta = args.Beta;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the Maximize property.
            metric.Maximize = args.Maximize;
        end

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TruePositives  = 0;
            metric.FalsePositives = 0;
            metric.FalseNegatives = 0;
        end

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the per-class TP, FP, and FN counts for this batch by
            % summing over the observations (dimension 2 for "CB" data).
            batchTruePositives = sum(batchY & batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);
            batchFalseNegatives = sum(~batchY & batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TruePositives = metric.TruePositives + batchTruePositives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
            metric.FalseNegatives = metric.FalseNegatives + batchFalseNegatives;
        end

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.

            metric.TruePositives = metric.TruePositives + metric2.TruePositives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
            metric.FalseNegatives = metric.FalseNegatives + metric2.FalseNegatives;
        end

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TP, FP, and FN values.
            tp = metric.TruePositives;
            fp = metric.FalsePositives;
            fn = metric.FalseNegatives;

            % F-beta score = (1+Beta^2)*TP/((1+Beta^2)*TP+(Beta^2)*FN+FP)
            betaSq = metric.Beta^2;
            betaSqAddOneTP = (1+betaSq).*tp;

            % Compute F-beta score.
            val = mean(betaSqAddOneTP./(betaSqAddOneTP+betaSq.*fn+fp+eps));
        end
    end
end
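Before training, you can sanity-check the class on a small hand-made mini-batch by calling its methods directly. This sketch assumes that fBetaMetric.m is saved in the current folder or on the MATLAB path and that Deep Learning Toolbox is installed. The scores and targets are made up and use the "CB" format that the update function expects.

```matlab
% Hypothetical scores and one-hot targets: 3 classes, 4 observations.
scores = [0.7 0.1 0.2 0.3
          0.2 0.8 0.3 0.3
          0.1 0.1 0.5 0.4];
targets = [1 0 0 0
           0 1 1 0
           0 0 0 1];

Y = dlarray(scores,"CB");
T = dlarray(targets,"CB");

% Mimic one iteration: reset the totals, update with a mini-batch, evaluate.
metric = fBetaMetric(Beta=1);
metric = reset(metric);
metric = update(metric,Y,T);
val = evaluate(metric);

% The result can be a dlarray, so extract the underlying data for display.
if isa(val,"dlarray")
    val = extractdata(val);
end
disp(val)
```

For this batch, the per-class scores are 1, 2/3, and 2/3, so val should be approximately 0.7778. If it is not, check the update function before using the metric in training.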

Use Custom Metric During Training

You can use a custom metric in the same way as any other metric in Deep Learning Toolbox™. This section shows how to create and train a network for digit classification and track the F-beta score with a beta value of 0.5.

Unzip the digit sample data and create an image datastore. The imageDatastore function automatically labels the images based on folder names.

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

Use 750 images of each label for training and the remaining images for validation.

numTrainingFiles = 750;
[imdsTrain,imdsVal] = splitEachLabel(imds,numTrainingFiles,"randomize");

Define the network architecture.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    fullyConnectedLayer(10)
    softmaxLayer];

Create an fBetaMetric object.

metric = fBetaMetric(Beta=0.5)
metric = 
  fBetaMetric with properties:

             Name: "FBeta"
             Beta: 0.5000
    NetworkOutput: []
         Maximize: 1

Specify the F-beta metric in the training options. To plot the metric during training, set Plots to "training-progress". To display the values during training, set Verbose to true. To return the network that achieves the best validation F-beta score, set ObjectiveMetricName to "FBeta" and OutputNetwork to "best-validation".

options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    Metrics=metric, ...
    ValidationData=imdsVal, ...
    ValidationFrequency=50, ...
    Verbose=true, ...
    Plots="training-progress", ...
    ObjectiveMetricName="FBeta", ...
    OutputNetwork="best-validation");

Train the network using the trainnet function. The values for the training and validation sets appear in the plot.

net = trainnet(imdsTrain,layers,"crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingFBeta    ValidationFBeta
    _________    _____    ___________    _________    ____________    ______________    _____________    _______________
            0        0       00:00:04        0.001                            13.488                            0.056862
            1        1       00:00:04        0.001          13.974                           0.037272                   
           50        1       00:00:17        0.001          2.7424            2.7448          0.68657            0.67177
          100        2       00:00:26        0.001          1.2965            1.2235          0.77512            0.79866
          150        3       00:00:34        0.001         0.64661           0.80412          0.87846            0.85387
          200        4       00:00:43        0.001         0.18627           0.53273          0.94685            0.89384
          250        5       00:00:52        0.001         0.16763           0.49371          0.94531            0.89466
          290        5       00:01:00        0.001         0.25976           0.39347          0.95055            0.91411
Training stopped: Max epochs completed

See Also


Related Topics