Global Patch Collider Demo

This sample trains the forest for the Global Patch Collider and stores output to the file "forest.yml.gz".

It then finds correspondences between two images using Global Patch Collider and calculates error using provided ground truth flow.

It will look for the file named "forest.yml.gz" with a learned forest. You can obtain the "forest.yml.gz" either by manually training it or by downloading one of the files trained on some publicly available dataset from here: https://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU

Sources:

Contents

1) Train

input training images

imgs1 = {
    fullfile(mexopencv.root(), 'test', 'RubberWhale1.png')
};
imgs2 = {
    fullfile(mexopencv.root(), 'test', 'RubberWhale2.png')
};
groundTruths = {
    fullfile(mexopencv.root(), 'test', 'RubberWhale.flo')
};
assert(isequal(numel(imgs1), numel(imgs2), numel(groundTruths)));

if exist(groundTruths{1}, 'file') ~= 2
    % attempt to download ground thruth flow from GitHub
    disp('Downloading FLO...')
    url = 'https://cdn.rawgit.com/opencv/opencv_extra/3.2.0/testdata/cv/optflow/RubberWhale.flo';
    urlwrite(url, groundTruths{1});
end

Global Patch Collider training paramters

params = {
    'MaxTreeDepth',20, ...      % Maximum tree depth to stop partitioning
    'MinNumberOfSamples',3, ... % Minimum number of samples in the node to stop partitioning
    'DescriptorType','DCT', ... % Descriptor type. Set to DCT for quality, WHT for speed
    'PrintProgress',false       % Set to false for quiet mode, set to true to print progress
};
forestDumpPath = fullfile(tempdir(), 'forest.yml.gz');

train the forest for the Global Patch Collider and save it

if exist(forestDumpPath, 'file') ~= 2
    gpc = cv.GPCForest();
    tic
    gpc.train(imgs1, imgs2, groundTruths);
    toc
    gpc.save(forestDumpPath);
end

2) Evaluate

test images

fromPath = imgs1{1};
toPath = imgs2{1};
gtPath = groundTruths{1};

from = imread(fromPath);
to = imread(toPath);
flo = cv.readOpticalFlow(gtPath);

load pretrained forest

forest = cv.GPCForest();
assert(exist(forestDumpPath,'file') == 2, 'No file with a trained model');
forest.load(forestDumpPath);

find correspondences between two the images using GPC

tic
corresp = forest.findCorrespondences(from, to, 'UseOpenCL',false);
toc
fprintf('Found %d matches\n', numel(corresp));
Elapsed time is 4.856060 seconds.
Found 20848 matches

calculate error using provided ground truth flow

gtU = flo(:,:,1);
gtV = flo(:,:,2);
a = cat(1, corresp.first);
b = cat(1, corresp.second);
ind = sub2ind(size(gtU), a(:,2), a(:,1));
gtDisplacement = [gtU(ind) gtV(ind)];
c = a + gtDisplacement;

% check for correct flow vector
mask = all(isfinite(gtDisplacement) & (gtDisplacement < 1e9), 2);
a = a(mask,:);
b = b(mask,:);
c = c(mask,:);

err = mean(sqrt(sum((b - c).^2, 2)));
fprintf('Average endpoint error = %f px.\n', err);
Average endpoint error = 0.929796 px.

display flows as color images

clr = getFlowColor(b - a);
dispOut = zeros(size(from), 'single');
dispOut(:,:,3) = 1;
dispOut = cv.circle(dispOut, a, 3, 'Colors',clr, 'Thickness','Filled');
dispOut = cv.cvtColor(dispOut, 'HSV2RGB');

clr = getFlowColor(b - c, false, 32);
dispErr = zeros(size(from), 'single');
dispErr(:,:,3) = 1;
dispErr = cv.circle(dispErr, a, 3, 'Colors',clr, 'Thickness','Filled');
dispErr = cv.cvtColor(dispErr, 'HSV2RGB');

dispGT = getFlowColor([gtU(:) gtV(:)]);
dispGT = reshape(dispGT(:,1:3), [size(gtU) 3]);
dispGT = cv.cvtColor(dispGT, 'HSV2RGB');

show results

opts = {'FontScale',0.8, 'Color','k', 'LineType','AA'};
str = 'Sparse matching: Global Patch Collider';
dispOut = cv.putText(dispOut, str, [20 40], opts{:});
str = sprintf('Average EPE: %.2f', err);
dispOut = cv.putText(dispOut, str, [20 80], opts{:});
str = sprintf('Number of matches: %d', nnz(mask));
dispOut = cv.putText(dispOut, str, [20 120], opts{:});

figure(1), imshow(dispOut), title('Correspondences')
figure(2), imshow(dispErr), title('Error')
figure(3), imshow(dispGT), title('Ground Truth')

Helper function

function clr = getFlowColor(UV, logScale, scaleDown)
    if nargin < 2, logScale = true; end
    if nargin < 3, scaleDown = 5; end

    angle = (atan2(-UV(:,2), -UV(:,1)) + pi) * 180 / pi;
    angle(all(UV == 0, 2)) = 0;

    radius = sqrt(sum(UV.^2, 2));
    if logScale
        radius = log(radius + 1);
    end
    radius = radius ./ scaleDown;
    radius = min(radius, 1);

    clr = [angle radius];
    clr(:,3) = 1;
    clr(:,4) = 0;
end