Contents

Semantic Segmentation Using FCN-AlexNet

本プログラムでは、MATLAB上でFCN-AlexNetを構築・学習し、 学習済みネットワークを評価するところまでのワークフローを試行します。 画像データはCamVidデータセットを利用します。詳細についてはReference[1]をご覧ください。

This example shows how to create, train and evaluate FCN-AlexNet.

This example uses the CamVid dataset [1] from the University of Cambridge for training.

clear all, close all, clc;

Setup

FCN-AlexNetはAlexNetをベースとしたネットワークであるため、AlexNetを使えるように 予めサポートパッケージがインストールされている必要があります。 インストール手順等につきましてはヘルプドキュメントをご覧ください。(>>doc('alexnet'))

alexnet();

FCN-AlexNetのトレーニングにはGPUの利用を強く推奨します。 (GPUの利用にはParallel Computing Toolbox™が必要です。)

Download CamVid Dataset

以下のURLよりCamVidデータセットをダウンロードします。500MB以上のサイズがありますので、 ディスク容量にご注意ください。

imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';

outputFolder = fullfile(tempdir, 'CamVid');

if ~exist(outputFolder, 'dir')
    disp('Downloading 557 MB CamVid dataset...');

    unzip(imageURL, fullfile(outputFolder,'images'));
    unzip(labelURL, fullfile(outputFolder,'labels'));
end

Load CamVid Images

ダウンロードしたCamVidデータセットに含まれる画像をロードします。 巨大なデータセットですが、imageDatastoreを利用してワークスペースを効率的に利用します。

imgDir = fullfile(outputFolder,'images','701_StillsRaw_full');
imds = imageDatastore(imgDir);

画像を1枚だけ読み込んで表示させます。

I = readimage(imds, 1);
I = histeq(I);
figure
imshow(I)

Load CamVid Pixel-Labeled Images

今度はラベル画像を読み込みます。pixelLabelDatastoreを利用します。 CamVidデータセットは32個のクラスを持っていますが、ここでは2クラスに纏めて利用します。

classes = [
    "Sky"
    "OtherObjects"
    ];

32クラスを2クラス('Sky'と'その他')に纏めます。'Car', 'Truck_Bus', 'Train'などのクラスは すべて'その他'とします。

labelIDs = camvidPixelLabelIDs();

上で定義した2クラスとIDを利用し、pixelLabelDatastoreを作成します。

labelDir = fullfile(outputFolder,'labels');
pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

ラベル画像を読み込み、該当する画像データにオーバーレイ表示させてみます。

C = readimage(pxds, 1);

cmap = camvidColorMap;
B = labeloverlay(I,C,'ColorMap',cmap);

figure
imshow(B)
pixelLabelColorbar(cmap,classes);
% 何も色が付いていない領域はラベルを持っていない領域となり、
%  ネットワークのトレーニングには利用されません。

Analyze Dataset Statistics

各クラスラベルの分布を見るために、countEachLabelを利用します。 この関数を使うことで、クラス毎の総ピクセル数を確認することが出来ます。

tbl = countEachLabel(pxds)
tbl =

  2×3 table

         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    'Sky'             7.6801e+07    4.8315e+08     
    'OtherObjects'    3.9316e+08    4.8453e+08     

分布をFigureで可視化

frequency = tbl.PixelCount/sum(tbl.PixelCount);

figure
bar(1:numel(classes),frequency)
xticks(1:numel(classes))
xticklabels(tbl.Name)
xtickangle(45)
ylabel('Frequency')

理想的には、各クラスの総ピクセル数が等しくなっているのが望ましいですが、一般的には そのようなケースは稀となります。一部のクラスの総ピクセル数のみ極端に少ないような場合、 その領域が正しく検出できなくともAccuracyに大きく影響を与えないことになってしまいますので、 重み付けをしてから学習させます。

Resize CamVid Data

CamVidデータセットの画像サイズは720x960です。学習に要する時間やメモリ使用量を 削減するため、360x480にリサイズします。

imageFolder = fullfile(outputFolder,'imagesReszed',filesep);
imds = resizeCamVidImages(imds,imageFolder);

labelFolder = fullfile(outputFolder,'labelsResized',filesep);
pxds = resizeCamVidPixelLabels(pxds,labelFolder);

Prepare Training and Test Sets

データセットの60%を学習に利用し、残りをテスト用とします。画像及びラベル画像を6:4に 分割します。

[imdsTrain, imdsTest, pxdsTrain, pxdsTest] = partitionCamVidData(imds,pxds);

分類した結果、学習用画像とテスト画像の総数がそれぞれどの程度あるか確認します。

numTrainingImages = numel(imdsTrain.Files)
numTestingImages = numel(imdsTest.Files)
numTrainingImages =

   421


numTestingImages =

   280

Create the Network

FCN-AlexNetを作成します。Computer Vision System ToolboxにはfcnLayersという関数があり 手軽にFCNを作成できるようになっていますが、この関数を利用して作成できるFCNはVGG-16を ベースとしたFCNとなっています。AlexNetをベースとしたFCNを構築したい場合、 各レイヤとネットワークアーキテクチャをLayerGraphオブジェクトで定義していく必要があります。

imageSize = [360 480];
numClasses = numel(classes);

まず、ベースとなるAlexNetをロードし、今回利用する画像データのサイズ[360 480] にあわせてimageInputLayerを定義します。

net = alexnet();

layers = net.Layers;

layers(1) = imageInputLayer([imageSize 3], 'Name', layers(1).Name,...
    'DataAugmentation', layers(1).DataAugmentation, ...
    'Normalization', layers(1).Normalization);

FCNの特徴でもありますが、全結合層を畳み込み層に置き換えます。 Weights/BiasはAlexNetから引き継ぎます

% fc6 is layers 17
idx = 17;
weights = layers(idx).Weights';
weights = reshape(weights, 6, 6, 256, 4096);
bias = reshape(layers(idx).Bias, 1, 1, []);

layers(idx) = convolution2dLayer(6, 4096, 'NumChannels', 256, 'Name', 'fc6');
layers(idx).Weights = weights;
layers(idx).Bias = bias;

% fc7 is layers 20
idx = 20;
weights = layers(idx).Weights';
weights = reshape(weights, 1, 1, 4096, 4096);
bias = reshape(layers(idx).Bias, 1, 1, []);

layers(idx) = convolution2dLayer(1, 4096, 'NumChannels', 4096, 'Name', 'fc7');
layers(idx).Weights = weights;
layers(idx).Bias = bias;

ネットワークの出力が入力画像と整合することを確実にするため、 [100 100]でパディングを行います。 (ネットワークの出力をCropして最終的に入力画像と該当箇所を一致させるため、 マージンを持たせます。)

conv1 = layers(2);
conv1New = convolution2dLayer(conv1.FilterSize, conv1.NumFilters, ...
    'Stride', conv1.Stride, ...
    'Padding', [100 100], ...
    'NumChannels', conv1.NumChannels, ...
    'WeightLearnRateFactor', conv1.WeightLearnRateFactor, ...
    'WeightL2Factor', conv1.WeightL2Factor, ...
    'BiasLearnRateFactor', conv1.BiasLearnRateFactor, ...
    'BiasL2Factor', conv1.BiasL2Factor, ...
    'Name', conv1.Name);
conv1New.Weights = conv1.Weights;
conv1New.Bias = conv1.Bias;

layers(2) = conv1New;

AlexNetからClassification Layerを削除します。

layers(end-2:end) = [];

畳み込み&Pooling層を経て得られた特徴マップは解像度が落ちているため、 アップサンプルするための転置畳み込み層を定義します。

upscore = transposedConv2dLayer(64, numClasses, ...
    'NumChannels', numClasses, 'Stride', 32, 'Name', 'upscore');

作成した各層を接続していきます。

layers = [
    layers
    convolution2dLayer(1, numClasses, 'Name', 'score_fr');
    upscore
    crop2dLayer('centercrop', 'Name', 'score')
    softmaxLayer('Name', 'softmax')
    pixelClassificationLayer('Name', 'pixelLabels')
    ];

lgraph = layerGraph(layers);

% imageInputLayerの出力をCropLayerにも接続します。
lgraph = connectLayers(lgraph, 'data', 'score/ref');

Balance Classes Using Class Weighting

前述した通り各クラスの総ピクセル数が不均一であるため、総ピクセル数に応じた 重み付けを行います。

imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq
classWeights =

    3.0523
    0.5980

定義した重みをpixelClassificationLayerに反映させます。

pxLayer = pixelClassificationLayer('Name','labels','ClassNames', tbl.Name, 'ClassWeights', classWeights)
pxLayer = 

  PixelClassificationLayer のプロパティ:

            Name: 'labels'
      ClassNames: {'Sky'  'OtherObjects'}
    ClassWeights: [2×1 double]
      OutputSize: 'auto'

   Hyperparameters
    LossFunction: 'crossentropyex'

FCN-AlexNetのpixelClassificationLayerを更新します。現在のpixelClassificationLayerを 削除し、新しく定義した層を追加します。

lgraph = removeLayers(lgraph, 'pixelLabels');
lgraph = addLayers(lgraph, pxLayer);
lgraph = connectLayers(lgraph, 'softmax' ,'labels');
%これで、FCN-AlexNetの完成です。
figure, plot(lgraph)

Select Training Options

学習時のオプションを指定します。勾配法はMomentum-SGDを利用します。 利用するGPUに搭載されたメモリサイズに応じて、MiniBatchSizeを調整してください。

options = trainingOptions('sgdm', ...
    'Momentum', 0.9, ...
    'InitialLearnRate', 1e-3, ...
    'L2Regularization', 0.0005, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 4, ...
    'Shuffle', 'every-epoch', ...
    'Plots','training-progress', ...
    'VerboseFrequency', 100);

Data Augmentation

学習に利用できる画像データの数が限られているため、Augmentation(数増し)を行います。 ここでは、ランダムに画像を左右反転させたり、X/Y方向に+/- 10ピクセルの範囲で平行移動させます。

augmenter = imageDataAugmenter('RandXReflection',true,...
    'RandXTranslation', [-10 10], 'RandYTranslation',[-10 10]);

Start Training

pixelLabelImageSourceを利用して、最終的に学習に用いるデータを定義します。 augmentationもここで行われます。

datasource = pixelLabelImageSource(imdsTrain,pxdsTrain,...
    'DataAugmentation',augmenter);

学習開始

[net, info] = trainNetwork(datasource,lgraph,options);
Training on single GPU.
Initializing image normalization.
|=========================================================================================|
|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|
|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |
|=========================================================================================|
|            1 |            1 |         3.04 |       0.6933 |       50.23% |       0.0010 |
|            1 |          100 |        23.02 |       0.6932 |       51.81% |       0.0010 |
|            2 |          200 |        42.41 |       0.6932 |       49.35% |       0.0010 |
|            3 |          300 |        61.68 |       0.6929 |       54.82% |       0.0010 |
|            4 |          400 |        81.01 |       0.6931 |       57.22% |       0.0010 |
|            5 |          500 |       100.02 |       0.6922 |       55.60% |       0.0010 |
|            6 |          600 |       119.05 |       0.6898 |       63.78% |       0.0010 |
|            7 |          700 |       138.41 |       0.6753 |       67.43% |       0.0010 |
|            8 |          800 |       157.36 |       0.4187 |       79.26% |       0.0010 |
|            9 |          900 |       176.37 |       0.1588 |       90.08% |       0.0010 |
|           10 |         1000 |       195.37 |       0.1311 |       90.39% |       0.0010 |
|           11 |         1100 |       214.29 |       0.1011 |       92.04% |       0.0010 |
|           12 |         1200 |       232.77 |       0.0955 |       93.00% |       0.0010 |
|           13 |         1300 |       251.28 |       0.1020 |       90.66% |       0.0010 |
|           14 |         1400 |       269.75 |       0.1129 |       91.74% |       0.0010 |
|           15 |         1500 |       288.23 |       0.1015 |       91.47% |       0.0010 |
|           16 |         1600 |       306.65 |       0.0828 |       92.68% |       0.0010 |
|           17 |         1700 |       325.16 |       0.1174 |       91.66% |       0.0010 |
|           18 |         1800 |       343.68 |       0.0801 |       92.56% |       0.0010 |
|           19 |         1900 |       362.20 |       0.0929 |       93.75% |       0.0010 |
|           20 |         2000 |       380.75 |       0.0729 |       89.97% |       0.0010 |
|           20 |         2100 |       399.32 |       0.0813 |       93.57% |       0.0010 |
|           21 |         2200 |       417.92 |       0.0577 |       94.47% |       0.0010 |
|           22 |         2300 |       436.48 |       0.0824 |       90.45% |       0.0010 |
|           23 |         2400 |       455.04 |       0.0754 |       93.00% |       0.0010 |
|           24 |         2500 |       473.58 |       0.0627 |       92.86% |       0.0010 |
|           25 |         2600 |       492.15 |       0.0708 |       93.09% |       0.0010 |
|           26 |         2700 |       511.03 |       0.0399 |       90.13% |       0.0010 |
|           27 |         2800 |       529.62 |       0.0746 |       92.07% |       0.0010 |
|           28 |         2900 |       548.25 |       0.0710 |       91.81% |       0.0010 |
|           29 |         3000 |       566.97 |       0.0737 |       86.36% |       0.0010 |
|           30 |         3100 |       585.67 |       0.0556 |       93.86% |       0.0010 |
|           30 |         3150 |       595.03 |       0.0656 |       91.65% |       0.0010 |
|=========================================================================================|

Test Network on One Image

テスト用画像を1枚読み込み、結果を表示させます。

idx = 147;
I = readimage(imdsTest,idx);
C = semanticseg(I, net);

% imshowで可視化
B = labeloverlay(I, C, 'Colormap', cmap, 'Transparency',0.4);
figure, imshowpair(I, B, 'montage')
pixelLabelColorbar(cmap, classes);

セグメンテーションの結果と、真値(Ground truth)を比較してみます。 重ね書きした結果、緑やマゼンタになっている箇所が真値とは異なる箇所になります。

expectedResult = readimage(pxdsTest,idx);
actual = uint8(C);
expected = uint8(expectedResult);
imshowpair(actual, expected)

FCNでセグメンテーションされた各クラスに対し、真値(Ground truth)領域が どの程度含まれているかを評価します。 これはIoUと呼ばれる指標で、jaccard関数を利用して測ることができます。

iou = jaccard(C, expectedResult);
table(classes,iou)
ans =

  2×2 table

       classes          iou  
    ______________    _______

    "Sky"             0.69784
    "OtherObjects"    0.92314

真値に対する類似度を表現する係数としては、Jaccardの他にも Dice係数やBF係数などが良く用いられます。 MATLAB上ではそれぞれdice関数、bfscore関数を使って求めることができます。

Evaluate Trained Network

テスト用に切り分けておいた全ての画像を用い、FCN-AlexNetの精度を測定します。

pxdsResults = semanticseg(imdsTest,net,'WriteLocation',tempdir,'Verbose',false);

semanticseg関数はpixelLabelDatastoreオブジェクトを返します。(ラベル画像のデータストア) 真値ラベルはpxdsTestデータストアにありますので、この2つのデータをevaluateSemanticSegmentation関数に与え、 各種評価指標を求めます。

metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTest,'Verbose',false);

データセット全体に対する指標(平均精度、平均IoU値等)

metrics.DataSetMetrics
ans =

  1×5 table

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

    0.96641           0.9721          0.8676     0.9101         0.53113    

各クラスに対する精度やIoU値

metrics.ClassMetrics
ans =

  2×3 table

                    Accuracy      IoU      MeanBFScore
                    ________    _______    ___________

    Sky              0.9804     0.80563    0.47428    
    OtherObjects    0.96381     0.92956    0.58653    

References

[1] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla. "Semantic object classes in video: A high-definition ground truth database." Pattern Recognition Letters Vol 30, Issue 2, 2009, pp 88-97.

Supporting Functions

function labelIDs = camvidPixelLabelIDs()
% Return the label IDs corresponding to each class.
%
% The CamVid dataset has 32 classes. Group them into 2 classes
% CamVidデータセットは32のクラスを有していますが、2クラス("Sky"と"それ以外")に纏めます。
labelIDs = { ...

    % "Sky"
    [
    128 128 128; ... % "Sky"
    ]

    % "OtherObjects"
    [
    000 128 064; ... % "Bridge"
    128 000 000; ... % "Building"
    064 192 000; ... % "Wall"
    064 000 064; ... % "Tunnel"
    192 000 128; ... % "Archway"
    192 192 128; ... % "Column_Pole"
    000 000 064; ... % "TrafficCone"
    128 064 128; ... % "Road"
    128 000 192; ... % "LaneMkgsDriv"
    192 000 064; ... % "LaneMkgsNonDriv"
    000 000 192; ... % "Sidewalk"
    064 192 128; ... % "ParkingBlock"
    128 128 192; ... % "RoadShoulder"
    128 128 000; ... % "Tree"
    192 192 000; ... % "VegetationMisc"
    192 128 128; ... % "SignSymbol"
    128 128 064; ... % "Misc_Text"
    000 064 064; ... % "TrafficLight"
    064 064 128; ... % "Fence"
    064 000 128; ... % "Car"
    064 128 192; ... % "SUVPickupTruck"
    192 128 192; ... % "Truck_Bus"
    192 064 128; ... % "Train"
    128 064 064; ... % "OtherMoving"
    064 064 000; ... % "Pedestrian"
    192 128 064; ... % "Child"
    064 000 192; ... % "CartLuggagePram"
    064 128 064; ... % "Animal"
    000 128 192; ... % "Bicyclist"
    192 000 192; ... % "MotorcycleScooter"
    ]

    };
end
function pixelLabelColorbar(cmap, classNames)
%分類するクラスに対応するcolorbarを追加します。
% Add a colorbar to the current axis. The colorbar is formatted
% to display the class names with the color.

colormap(gca,cmap)

% Add colorbar to current figure.
c = colorbar('peer', gca);

% Use class names for tick marks.
c.TickLabels = classNames;
numClasses = size(cmap,1);

% Center tick labels.
c.Ticks = 1/(numClasses*2):1/numClasses:1;

% Remove tick mark.
c.TickLength = 0;
end
function cmap = camvidColorMap()
% CamVidデータセットの各クラスに対して紐付けられる色(colormap)を定義します。
% Define the colormap used by CamVid dataset.

cmap = [
    60 40 222   % Sky
    128 0 0     % OtherObjects
    ];

% Normalize between [0 1].
cmap = cmap ./ 255;
end
function imds = resizeCamVidImages(imds, imageFolder)
% Resize images to [360 480].

if ~exist(imageFolder,'dir')
    mkdir(imageFolder)
else
    imds = imageDatastore(imageFolder);
    return; % Skip if images already resized
end

reset(imds)
while hasdata(imds)
    % Read an image.
    [I,info] = read(imds);

    % Resize image.
    I = imresize(I,[360 480]);

    % Write to disk.
    [~, filename, ext] = fileparts(info.Filename);
    imwrite(I,[imageFolder filename ext])
end

imds = imageDatastore(imageFolder);
end
function pxds = resizeCamVidPixelLabels(pxds, labelFolder)
% Resize pixel label data to [360 480].

classes = pxds.ClassNames;
labelIDs = 1:numel(classes);
if ~exist(labelFolder,'dir')
    mkdir(labelFolder)
else
    pxds = pixelLabelDatastore(labelFolder,classes,labelIDs);
    return; % Skip if images already resized
end

reset(pxds)
while hasdata(pxds)
    % Read the pixel data.
    [C,info] = read(pxds);

    % Convert from categorical to uint8.
    L = uint8(C);

    % Resize the data. Use 'nearest' interpolation to
    % preserve label IDs.
    L = imresize(L,[360 480],'nearest');

    % Write the data to disk.
    [~, filename, ext] = fileparts(info.Filename);
    imwrite(L,[labelFolder filename ext])
end

labelIDs = 1:numel(classes);
pxds = pixelLabelDatastore(labelFolder,classes,labelIDs);
end
function [imdsTrain, imdsTest, pxdsTrain, pxdsTest] = partitionCamVidData(imds,pxds)
% Partition CamVid data by randomly selecting 60% of the data for training. The
% rest is used for testing.

% Set initial random state for example reproducibility.
rng(0);
numFiles = numel(imds.Files);
shuffledIndices = randperm(numFiles);

% Use 60% of the images for training.
N = round(0.60 * numFiles);
trainingIdx = shuffledIndices(1:N);

% Use the rest for testing.
testIdx = shuffledIndices(N+1:end);

% Create image datastores for training and test.
trainingImages = imds.Files(trainingIdx);
testImages = imds.Files(testIdx);
imdsTrain = imageDatastore(trainingImages);
imdsTest = imageDatastore(testImages);

% Extract class and label IDs info.
classes = pxds.ClassNames;
labelIDs = 1:numel(pxds.ClassNames);

% Create pixel label datastores for training and test.
trainingLabels = pxds.Files(trainingIdx);
testLabels = pxds.Files(testIdx);
pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs);
pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs);
end